mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Code Clean] Replace std::runtime_error with TORCH_CHECK (#165209)
Including: 1. `aten/src/ATen/core` 2. `c10/core` Fixes part of #148114 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165209 Approved by: https://github.com/FFFrog, https://github.com/albanD
This commit is contained in:
parent
2b748d0a56
commit
12aac12b8d
|
|
@ -59,9 +59,7 @@ struct TORCH_API Generator {
|
|||
|
||||
explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
|
||||
: impl_(std::move(gen_impl)) {
|
||||
if (impl_.get() == nullptr) {
|
||||
throw std::runtime_error("GeneratorImpl with nullptr is not supported");
|
||||
}
|
||||
TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported");
|
||||
}
|
||||
|
||||
bool operator==(const Generator& rhs) const {
|
||||
|
|
|
|||
|
|
@ -111,9 +111,7 @@ class TORCH_API TensorBase {
|
|||
explicit TensorBase(
|
||||
c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
|
||||
: impl_(std::move(tensor_impl)) {
|
||||
if (impl_.get() == nullptr) {
|
||||
throw std::runtime_error("TensorImpl with nullptr is not supported");
|
||||
}
|
||||
TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported");
|
||||
}
|
||||
TensorBase(const TensorBase&) = default;
|
||||
TensorBase(TensorBase&&) noexcept = default;
|
||||
|
|
|
|||
|
|
@ -68,11 +68,7 @@ Symbol InternedStrings::_symbol(const std::string& s) {
|
|||
return it->second;
|
||||
|
||||
auto pos = s.find("::");
|
||||
if (pos == std::string::npos) {
|
||||
std::stringstream ss;
|
||||
ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s);
|
||||
Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
|
||||
|
||||
Symbol sym(sym_to_info_.size());
|
||||
|
|
@ -121,12 +117,7 @@ std::string Symbol::domainString() const {
|
|||
}
|
||||
|
||||
Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
|
||||
if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
|
||||
std::ostringstream ss;
|
||||
ss << "Symbol: domain string is expected to be prefixed with '"
|
||||
<< domain_prefix() << "', e.g. 'org.pytorch.aten'";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'");
|
||||
std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
|
||||
return fromQualString(qualString);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/stack.h>
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
#include <c10/util/hash.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
|
@ -412,7 +413,7 @@ size_t IValue::hash(const IValue& v) {
|
|||
case Tag::Enum:
|
||||
case Tag::Stream:
|
||||
case Tag::Uninitialized:
|
||||
throw std::runtime_error(
|
||||
TORCH_CHECK(false,
|
||||
"unhashable type: '" + v.type()->repr_str() + "'");
|
||||
}
|
||||
// the above switch should be exhaustive
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include <ATen/core/type_factory.h>
|
||||
#include <ATen/core/qualified_name.h>
|
||||
#include <c10/util/TypeList.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <optional>
|
||||
#include <c10/core/SymFloat.h>
|
||||
#include <c10/core/SymBool.h>
|
||||
|
|
@ -116,10 +117,8 @@ struct SingleElementType : public SharedType {
|
|||
|
||||
protected:
|
||||
SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
|
||||
if (!this->elem) {
|
||||
throw std::runtime_error(c10::str(
|
||||
TORCH_CHECK(this->elem, c10::str(
|
||||
"Can not create ", typeKindToString(Kind), " with None type"));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -416,16 +415,12 @@ struct TORCH_API SymbolicShape {
|
|||
}
|
||||
|
||||
ShapeSymbol operator[](size_t i) const {
|
||||
if (!dims_) {
|
||||
throw std::runtime_error("Rank isn't fixed");
|
||||
}
|
||||
TORCH_CHECK(dims_, "Rank isn't fixed");
|
||||
return (*dims_).at(i);
|
||||
}
|
||||
|
||||
ShapeSymbol at(size_t i) const {
|
||||
if (!dims_) {
|
||||
throw std::runtime_error("Rank isn't fixed");
|
||||
}
|
||||
TORCH_CHECK(dims_, "Rank isn't fixed");
|
||||
return (*dims_).at(i);
|
||||
}
|
||||
|
||||
|
|
@ -520,9 +515,7 @@ struct VaryingShape {
|
|||
}
|
||||
|
||||
const std::optional<T> &operator[](size_t i) const {
|
||||
if (!dims_) {
|
||||
throw std::runtime_error("Rank isn't fixed");
|
||||
}
|
||||
TORCH_CHECK(dims_, "Rank isn't fixed");
|
||||
return (*dims_).at(i);
|
||||
}
|
||||
|
||||
|
|
@ -957,9 +950,7 @@ struct TORCH_API DictType : public SharedType {
|
|||
|
||||
TypePtr createWithContained(
|
||||
std::vector<TypePtr> contained_types) const override {
|
||||
if (contained_types.size() != 2) {
|
||||
throw std::runtime_error("Expected 2 contained types");
|
||||
}
|
||||
TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types");
|
||||
return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include <ATen/core/jit_type.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/env.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/flat_hash_map.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <array>
|
||||
|
|
@ -826,9 +827,7 @@ TupleType::TupleType(
|
|||
: NamedType(TypeKind::TupleType, std::move(name)),
|
||||
elements_(std::move(elements)),
|
||||
has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
|
||||
if (!v) {
|
||||
throw std::runtime_error("Can not create tuple with None type");
|
||||
}
|
||||
TORCH_CHECK(v, "Can not create tuple with None type");
|
||||
return v->hasFreeVariables();
|
||||
})), schema_(std::move(schema)) {
|
||||
|
||||
|
|
|
|||
|
|
@ -223,7 +223,7 @@ inline DispatchKey backendToDispatchKey(Backend b) {
|
|||
case Backend::PrivateUse1:
|
||||
return DispatchKey::PrivateUse1;
|
||||
default:
|
||||
throw std::runtime_error("Unknown backend");
|
||||
TORCH_CHECK(false, "Unknown backend");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -336,7 +336,7 @@ class C10_API Scalar {
|
|||
} else if (isBoolean()) {
|
||||
return ScalarType::Bool;
|
||||
} else {
|
||||
throw std::runtime_error("Unknown scalar type.");
|
||||
TORCH_CHECK(false, "Unknown scalar type.");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -228,7 +228,7 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
|
|||
case c10::ScalarType::Float4_e2m1fn_x2:
|
||||
return std::make_pair("float4_e2m1fn_x2", "");
|
||||
default:
|
||||
throw std::runtime_error("Unimplemented scalar type");
|
||||
TORCH_CHECK(false, "Unimplemented scalar type");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -87,9 +87,7 @@ bool ThreadPool::inThreadPool() const {
|
|||
}
|
||||
|
||||
void ThreadPool::run(std::function<void()> func) {
|
||||
if (threads_.empty()) {
|
||||
throw std::runtime_error("No threads to run a task");
|
||||
}
|
||||
TORCH_CHECK(threads_.size() > 0, "No threads to run a task");
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
|
||||
// Set task and signal condition variable so that a worker thread will
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user