mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: As GoogleTest `TEST` macro is non-compliant with it as well as `DEFINE_DISPATCH` All changes but the ones to `.clang-tidy` are generated using following script: ``` for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`; do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008 Reviewed By: driazati, r-barnes Differential Revision: D29838584 Pulled By: malfet fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
217 lines
5.1 KiB
C++
217 lines
5.1 KiB
C++
#include <c10/util/Backtrace.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <c10/util/Logging.h>
|
|
#include <c10/util/Type.h>
|
|
|
|
#include <iostream>
|
|
#include <numeric>
|
|
#include <sstream>
|
|
#include <string>
|
|
|
|
namespace c10 {
|
|
|
|
// Primary constructor. Takes the fully formatted message and the backtrace by
// value and moves them into the members, records the opaque `caller` pointer,
// then populates the cached what-strings via refresh_what().
Error::Error(std::string msg, std::string backtrace, const void* caller)
    : msg_(std::move(msg)),
      backtrace_(std::move(backtrace)),
      caller_(caller) {
  // Build what_ and what_without_backtrace_ up front from msg_/backtrace_.
  refresh_what();
}
|
|
|
|
// PyTorch-style error message
|
|
// Error::Error(SourceLocation source_location, const std::string& msg)
|
|
// NB: This is defined in Logging.cpp for access to GetFetchStackTrace
|
|
|
|
// Caffe2-style error message
|
|
Error::Error(
|
|
const char* file,
|
|
const uint32_t line,
|
|
const char* condition,
|
|
const std::string& msg,
|
|
const std::string& backtrace,
|
|
const void* caller)
|
|
: Error(
|
|
str("[enforce fail at ",
|
|
detail::StripBasename(file),
|
|
":",
|
|
line,
|
|
"] ",
|
|
condition,
|
|
". ",
|
|
msg),
|
|
backtrace,
|
|
caller) {}
|
|
|
|
std::string Error::compute_what(bool include_backtrace) const {
|
|
std::ostringstream oss;
|
|
|
|
oss << msg_;
|
|
|
|
if (context_.size() == 1) {
|
|
// Fold error and context in one line
|
|
oss << " (" << context_[0] << ")";
|
|
} else {
|
|
for (const auto& c : context_) {
|
|
oss << "\n " << c;
|
|
}
|
|
}
|
|
|
|
if (include_backtrace) {
|
|
oss << "\n" << backtrace_;
|
|
}
|
|
|
|
return oss.str();
|
|
}
|
|
|
|
// Recomputes both cached description strings from the current message,
// context and backtrace. The with-backtrace variant is built first.
void Error::refresh_what() {
  what_ = compute_what(/*include_backtrace=*/true);
  what_without_backtrace_ = compute_what(/*include_backtrace=*/false);
}
|
|
|
|
// Appends `new_msg` to this error's context list and immediately rebuilds the
// cached what-strings so the new entry is visible.
void Error::add_context(std::string new_msg) {
  context_.emplace_back(std::move(new_msg));
  // TODO: Each call rebuilds the full description, so n calls cost O(n^2).
  // Populating the fields lazily would fix this, if it ever becomes a real
  // problem.
  // NB: Any such fix must stay thread safe. Callers almost certainly expect
  // what() to be safe to call from multiple threads concurrently.
  refresh_what();
}
|
|
|
|
namespace detail {
|
|
|
|
void torchCheckFail(
|
|
const char* func,
|
|
const char* file,
|
|
uint32_t line,
|
|
const std::string& msg) {
|
|
throw ::c10::Error({func, file, line}, msg);
|
|
}
|
|
|
|
void torchCheckFail(
|
|
const char* func,
|
|
const char* file,
|
|
uint32_t line,
|
|
const char* msg) {
|
|
throw ::c10::Error({func, file, line}, msg);
|
|
}
|
|
|
|
void torchInternalAssertFail(
|
|
const char* func,
|
|
const char* file,
|
|
uint32_t line,
|
|
const char* condMsg,
|
|
const char* userMsg) {
|
|
torchCheckFail(func, file, line, c10::str(condMsg, userMsg));
|
|
}
|
|
|
|
// This should never be called. It is provided in case of compilers
|
|
// that don't do any dead code stripping in debug builds.
|
|
void torchInternalAssertFail(
|
|
const char* func,
|
|
const char* file,
|
|
uint32_t line,
|
|
const char* condMsg,
|
|
const std::string& userMsg) {
|
|
torchCheckFail(func, file, line, c10::str(condMsg, userMsg));
|
|
}
|
|
|
|
} // namespace detail
|
|
|
|
namespace Warning {
|
|
|
|
namespace {
|
|
WarningHandler* getBaseHandler() {
|
|
static WarningHandler base_warning_handler_ = WarningHandler();
|
|
return &base_warning_handler_;
|
|
};
|
|
|
|
class ThreadWarningHandler {
|
|
public:
|
|
ThreadWarningHandler() = delete;
|
|
|
|
static WarningHandler* get_handler() {
|
|
if (!warning_handler_) {
|
|
warning_handler_ = getBaseHandler();
|
|
}
|
|
return warning_handler_;
|
|
}
|
|
|
|
static void set_handler(WarningHandler* handler) {
|
|
warning_handler_ = handler;
|
|
}
|
|
|
|
private:
|
|
static thread_local WarningHandler* warning_handler_;
|
|
};
|
|
|
|
thread_local WarningHandler* ThreadWarningHandler::warning_handler_ = nullptr;
|
|
|
|
} // namespace
|
|
|
|
void warn(
|
|
const SourceLocation& source_location,
|
|
const std::string& msg,
|
|
const bool verbatim) {
|
|
ThreadWarningHandler::get_handler()->process(source_location, msg, verbatim);
|
|
}
|
|
|
|
void warn(
|
|
SourceLocation source_location,
|
|
detail::CompileTimeEmptyString msg,
|
|
const bool verbatim) {
|
|
warn(source_location, "", verbatim);
|
|
}
|
|
|
|
void warn(
|
|
SourceLocation source_location,
|
|
const char* msg,
|
|
const bool verbatim) {
|
|
ThreadWarningHandler::get_handler()->process(source_location, msg, verbatim);
|
|
}
|
|
|
|
void set_warning_handler(WarningHandler* handler) noexcept(true) {
|
|
ThreadWarningHandler::set_handler(handler);
|
|
}
|
|
|
|
WarningHandler* get_warning_handler() noexcept(true) {
|
|
return ThreadWarningHandler::get_handler();
|
|
}
|
|
|
|
// Global "warn always" flag, initially false. It is a plain bool shared by all
// threads (not thread_local, not atomic). Access it through the setter and
// getter below.
bool warn_always{false};

// Overwrites the global flag with `setting`.
void set_warnAlways(bool setting) noexcept(true) {
  warn_always = setting;
}

// Reads the current value of the global flag.
bool get_warnAlways() noexcept(true) {
  const bool current = warn_always;
  return current;
}
|
|
|
|
// Scope guard for the global warn-always flag. The constructor remembers the
// current value and then applies `setting`. The destructor restores the
// remembered value.
WarnAlways::WarnAlways(bool setting /*=true*/)
    : prev_setting(get_warnAlways()) {
  set_warnAlways(setting);
}

// Puts back whatever value the flag held when this guard was constructed.
WarnAlways::~WarnAlways() {
  set_warnAlways(prev_setting);
}
|
|
|
|
} // namespace Warning
|
|
|
|
void WarningHandler::process(
|
|
const SourceLocation& source_location,
|
|
const std::string& msg,
|
|
const bool /*verbatim*/) {
|
|
LOG_AT_FILE_LINE(WARNING, source_location.file, source_location.line)
|
|
<< "Warning: " << msg << " (function " << source_location.function << ")";
|
|
}
|
|
|
|
std::string GetExceptionString(const std::exception& e) {
|
|
#ifdef __GXX_RTTI
|
|
return demangle(typeid(e).name()) + ": " + e.what();
|
|
#else
|
|
return std::string("Exception (no RTTI available): ") + e.what();
|
|
#endif // __GXX_RTTI
|
|
}
|
|
|
|
} // namespace c10
|