[Code Clean] Replace std::runtime_error with TORCH_CHECK (#165209)

Including:
1. `aten/src/ATen/core`
2. `c10/core`

Fixes part of #148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165209
Approved by: https://github.com/FFFrog, https://github.com/albanD
This commit is contained in:
KarhouTam 2025-10-22 00:05:16 +00:00 committed by PyTorch MergeBot
parent 2b748d0a56
commit 12aac12b8d
10 changed files with 18 additions and 42 deletions

View File

@ -59,9 +59,7 @@ struct TORCH_API Generator {
explicit Generator(c10::intrusive_ptr<c10::GeneratorImpl> gen_impl)
: impl_(std::move(gen_impl)) {
if (impl_.get() == nullptr) {
throw std::runtime_error("GeneratorImpl with nullptr is not supported");
}
TORCH_CHECK(impl_.get(), "GeneratorImpl with nullptr is not supported");
}
bool operator==(const Generator& rhs) const {

View File

@ -111,9 +111,7 @@ class TORCH_API TensorBase {
explicit TensorBase(
c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
: impl_(std::move(tensor_impl)) {
if (impl_.get() == nullptr) {
throw std::runtime_error("TensorImpl with nullptr is not supported");
}
TORCH_CHECK(impl_.get(), "TensorImpl with nullptr is not supported");
}
TensorBase(const TensorBase&) = default;
TensorBase(TensorBase&&) noexcept = default;

View File

@ -68,11 +68,7 @@ Symbol InternedStrings::_symbol(const std::string& s) {
return it->second;
auto pos = s.find("::");
if (pos == std::string::npos) {
std::stringstream ss;
ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
throw std::runtime_error(ss.str());
}
TORCH_CHECK(pos != std::string::npos, "all symbols must have a namespace, <namespace>::<string>, but found: ", s);
Symbol ns = _symbol("namespaces::" + s.substr(0, pos));
Symbol sym(sym_to_info_.size());
@ -121,12 +117,7 @@ std::string Symbol::domainString() const {
}
Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
if (d.compare(0, domain_prefix().size(), domain_prefix()) != 0) {
std::ostringstream ss;
ss << "Symbol: domain string is expected to be prefixed with '"
<< domain_prefix() << "', e.g. 'org.pytorch.aten'";
throw std::runtime_error(ss.str());
}
TORCH_CHECK(d.compare(0, domain_prefix().size(), domain_prefix()) == 0, "Symbol: domain string is expected to be prefixed with '", domain_prefix(), "', e.g. 'org.pytorch.aten'");
std::string qualString = d.substr(domain_prefix().size()) + "::" + s;
return fromQualString(qualString);
}

View File

@ -7,6 +7,7 @@
#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <ATen/core/type_factory.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
@ -412,7 +413,7 @@ size_t IValue::hash(const IValue& v) {
case Tag::Enum:
case Tag::Stream:
case Tag::Uninitialized:
throw std::runtime_error(
TORCH_CHECK(false,
"unhashable type: '" + v.type()->repr_str() + "'");
}
// the above switch should be exhaustive

View File

@ -8,6 +8,7 @@
#include <ATen/core/type_factory.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/TypeList.h>
#include <c10/util/Exception.h>
#include <optional>
#include <c10/core/SymFloat.h>
#include <c10/core/SymBool.h>
@ -116,10 +117,8 @@ struct SingleElementType : public SharedType {
protected:
SingleElementType(TypePtr elem) : SharedType(Kind), elem(std::move(elem)) {
if (!this->elem) {
throw std::runtime_error(c10::str(
TORCH_CHECK(this->elem, c10::str(
"Can not create ", typeKindToString(Kind), " with None type"));
}
}
private:
@ -416,16 +415,12 @@ struct TORCH_API SymbolicShape {
}
ShapeSymbol operator[](size_t i) const {
if (!dims_) {
throw std::runtime_error("Rank isn't fixed");
}
TORCH_CHECK(dims_, "Rank isn't fixed");
return (*dims_).at(i);
}
ShapeSymbol at(size_t i) const {
if (!dims_) {
throw std::runtime_error("Rank isn't fixed");
}
TORCH_CHECK(dims_, "Rank isn't fixed");
return (*dims_).at(i);
}
@ -520,9 +515,7 @@ struct VaryingShape {
}
const std::optional<T> &operator[](size_t i) const {
if (!dims_) {
throw std::runtime_error("Rank isn't fixed");
}
TORCH_CHECK(dims_, "Rank isn't fixed");
return (*dims_).at(i);
}
@ -957,9 +950,7 @@ struct TORCH_API DictType : public SharedType {
TypePtr createWithContained(
std::vector<TypePtr> contained_types) const override {
if (contained_types.size() != 2) {
throw std::runtime_error("Expected 2 contained types");
}
TORCH_CHECK(contained_types.size() == 2, "Expected 2 contained types");
return create(std::move(contained_types.at(0)), std::move(contained_types.at(1)));
}

View File

@ -8,6 +8,7 @@
#include <ATen/core/jit_type.h>
#include <c10/macros/Macros.h>
#include <c10/util/env.h>
#include <c10/util/Exception.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <array>
@ -826,9 +827,7 @@ TupleType::TupleType(
: NamedType(TypeKind::TupleType, std::move(name)),
elements_(std::move(elements)),
has_free_variables_(std::any_of(elements_.begin(), elements_.end(), [](const TypePtr& v) {
if (!v) {
throw std::runtime_error("Can not create tuple with None type");
}
TORCH_CHECK(v, "Can not create tuple with None type");
return v->hasFreeVariables();
})), schema_(std::move(schema)) {

View File

@ -223,7 +223,7 @@ inline DispatchKey backendToDispatchKey(Backend b) {
case Backend::PrivateUse1:
return DispatchKey::PrivateUse1;
default:
throw std::runtime_error("Unknown backend");
TORCH_CHECK(false, "Unknown backend");
}
}

View File

@ -336,7 +336,7 @@ class C10_API Scalar {
} else if (isBoolean()) {
return ScalarType::Bool;
} else {
throw std::runtime_error("Unknown scalar type.");
TORCH_CHECK(false, "Unknown scalar type.");
}
}

View File

@ -228,7 +228,7 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
case c10::ScalarType::Float4_e2m1fn_x2:
return std::make_pair("float4_e2m1fn_x2", "");
default:
throw std::runtime_error("Unimplemented scalar type");
TORCH_CHECK(false, "Unimplemented scalar type");
}
}

View File

@ -87,9 +87,7 @@ bool ThreadPool::inThreadPool() const {
}
void ThreadPool::run(std::function<void()> func) {
if (threads_.empty()) {
throw std::runtime_error("No threads to run a task");
}
TORCH_CHECK(threads_.size() > 0, "No threads to run a task");
std::unique_lock<std::mutex> lock(mutex_);
// Set task and signal condition variable so that a worker thread will