[Code Clean] Replace std::runtime_error with TORCH_CHECK (#163437)

Replace throws of the vanilla C++ exception std::runtime_error with TORCH_CHECK (the before/after pattern is sketched below).
Including:
- torch/csrc/export
- torch/csrc/cuda
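
The pattern is mechanical: a conditional throw becomes a single TORCH_CHECK whose condition is the inverted guard, and unconditional throws become TORCH_CHECK(false, ...). A minimal before/after sketch, using one of the checks from this diff:

    // Before: plain C++ exception, bypasses c10 error handling
    if (len == 0) {
      throw std::runtime_error("input sequence can't be empty");
    }

    // After: TORCH_CHECK throws c10::Error when the condition is false
    TORCH_CHECK(len != 0, "input sequence can't be empty");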

Fixes #148114
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163437
Approved by: https://github.com/Skylion007, https://github.com/cyyever
Authored by zhudada 2025-10-12 01:23:02 +00:00; committed by PyTorch MergeBot
parent bb0635d7dd
commit 058814794b
7 changed files with 84 additions and 101 deletions

View File

@@ -4,6 +4,7 @@
 #include <ATen/native/ConvUtils.h>
 #include <c10/core/Device.h>
 #include <c10/core/TensorImpl.h>
+#include <c10/util/Exception.h>
 #include <c10/util/UniqueVoidPtr.h>
 #include <pybind11/pytypes.h>
 #include <torch/csrc/utils/python_arg_parser.h>
@@ -861,7 +862,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
       case TraceEntry::SEGMENT_MAP:
         return segment_map_s;
     }
-    throw std::runtime_error("unreachable");
+    TORCH_CHECK(false, "unreachable");
   };
   for (const auto& traceInfo : snapshot.device_traces) {

View File

@@ -1,6 +1,7 @@
 #include <ATen/Context.h>
 #include <ATen/record_function.h>
 #include <c10/cuda/CUDACachingAllocator.h>
+#include <c10/util/Exception.h>
 #include <torch/csrc/cuda/memory_snapshot.h>
 #include <torch/csrc/jit/runtime/interpreter.h>
 #include <torch/csrc/jit/serialization/pickler.h>
@@ -413,7 +414,7 @@ std::string _memory_snapshot_pickled() {
       case TraceEntry::SEGMENT_MAP:
         return segment_map_s;
     }
-    throw std::runtime_error("unreachable");
+    TORCH_CHECK(false, "unreachable");
   };
   for (const auto& traceInfo : snapshot.device_traces) {

View File

@@ -62,7 +62,7 @@ ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
     case torch::cuda::nccl::ncclResult::NumResults:
      return ncclResult_t::ncclNumResults;
    default:
-      throw std::runtime_error("Unconvertible NCCL type");
+      TORCH_CHECK(false, "Unconvertible NCCL type");
   }
 }
@@ -91,7 +91,7 @@ torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
     case ncclNumResults:
      return torch::cuda::nccl::ncclResult::NumResults;
    default:
-      throw std::runtime_error("Unconvertible NCCL type");
+      TORCH_CHECK(false, "Unconvertible NCCL type");
   }
 }
@@ -194,10 +194,9 @@ static void NCCL_CHECK_TIMEOUT(ncclResult status, ncclComm_t comm) {
     auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
                            currentTimepoint - startTimepoint)
                            .count();
-    if (timeElapsed > nccl_nonblocking_timeout()) {
-      throw std::runtime_error(
-          "NCCL timeout when waiting for nonblocking call to become successful.");
-    }
+    TORCH_CHECK(
+        timeElapsed <= nccl_nonblocking_timeout(),
+        "NCCL timeout when waiting for nonblocking call to become successful.");
     sched_yield(); // yield to other threads
     ncclCommGetAsyncError(to_nccl_comm(comm), &result);
   }
@@ -227,10 +226,9 @@ static void NCCL_CHECK_TIMEOUT(
       auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
                              currentTimepoint - startTimepoint)
                              .count();
-      if (timeElapsed > nccl_nonblocking_timeout()) {
-        throw std::runtime_error(
-            "NCCL timeout when waiting for nonblocking call to become successful.");
-      }
+      TORCH_CHECK(
+          timeElapsed <= nccl_nonblocking_timeout(),
+          "NCCL timeout when waiting for nonblocking call to become successful.");
       sched_yield(); // yield to other threads
       ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result);
     } while (result == ncclInProgress);
@@ -258,7 +256,7 @@ void throw_nccl_error(torch::cuda::nccl::ncclResult status) {
   std::ostringstream err;
   err << "NCCL Error " << static_cast<int>(status) << ": "
       << ncclGetErrorString(to_nccl_result(status));
-  throw std::runtime_error(err.str());
+  TORCH_CHECK(false, err.str());
 }
 
 struct NcclCommList {
@@ -318,41 +316,36 @@ static void check_tensor(
     int64_t ref_numel,
     ScalarType ref_dtype) {
   auto check_one = [&](const at::Tensor& tensor) {
-    if (!tensor.is_cuda() || tensor.is_sparse()) {
-      throw std::runtime_error(
-          "input and output elements have to be cuda dense Tensors");
-    }
+    TORCH_CHECK(
+        tensor.is_cuda() && !tensor.is_sparse(),
+        "input and output elements have to be cuda dense Tensors");
 
-    if (ref_dtype != tensor.scalar_type()) {
-      throw std::runtime_error(
-          "all inputs and outputs must be of the same Tensor dtype");
-    }
+    TORCH_CHECK(
+        ref_dtype == tensor.scalar_type(),
+        "all inputs and outputs must be of the same Tensor dtype");
 
-    if (!tensor.is_contiguous()) {
-      throw std::runtime_error("all inputs and outputs have to be contiguous");
-    }
+    TORCH_CHECK(
+        tensor.is_contiguous(), "all inputs and outputs have to be contiguous");
   };
 
   check_one(input);
 
   // all inputs must be same size
-  if (input.numel() != ref_numel) {
-    throw std::runtime_error(
-        "all inputs must have the same number of elements");
-  }
+  TORCH_CHECK(
+      input.numel() == ref_numel,
+      "all inputs must have the same number of elements");
 
   if (output) {
     check_one(*output);
 
     // inputs and outputs must be on same device respectively
-    if (input.get_device() != output->get_device()) {
-      throw std::runtime_error("input and output must be on the same device");
-    }
+    TORCH_CHECK(
+        input.get_device() == output->get_device(),
+        "input and output must be on the same device");
 
-    if (output->numel() * output_multiplier != ref_numel * input_multiplier) {
-      throw std::runtime_error(
-          "output must be of size input_size * size_multiplier");
-    }
+    TORCH_CHECK(
+        output->numel() * output_multiplier == ref_numel * input_multiplier,
+        "output must be of size input_size * size_multiplier");
   }
 }
@@ -364,15 +357,13 @@ void check_inputs(
   // len(inputs) == len(outputs)
   size_t len = inputs.size();
 
-  if (len == 0) {
-    throw std::runtime_error("input sequence can't be empty");
-  }
+  TORCH_CHECK(len != 0, "input sequence can't be empty");
 
   if (len != outputs.size()) {
     std::stringstream err;
     err << "inputs and outputs sequences have to be of the same length, but got input of length "
        << len << " and output of length " << outputs.size();
-    throw std::runtime_error(err.str());
+    TORCH_CHECK(false, err.str());
   }
 
   device_set devices;
@@ -388,9 +379,8 @@ void check_inputs(
     auto input_device = input.get_device();
     // inputs must be on unique devices
-    if (devices.test(input_device)) {
-      throw std::runtime_error("inputs must be on unique devices");
-    }
+    TORCH_CHECK(
+        !devices.test(input_device), "inputs must be on unique devices");
     devices.set(input_device);
   }
 }
@@ -403,9 +393,7 @@ void check_inputs(
     int output_multiplier) {
   auto len = inputs.size();
-  if (len <= 0) {
-    throw std::runtime_error("input sequence can't be empty");
-  }
+  TORCH_CHECK(len > 0, "input sequence can't be empty");
 
   device_set devices;
   int64_t numel = inputs[0].numel();
@@ -426,9 +414,8 @@ void check_inputs(
     auto input_device = input.get_device();
     // inputs must be on unique devices
-    if (devices.test(input_device)) {
-      throw std::runtime_error("inputs must be on unique devices");
-    }
+    TORCH_CHECK(
+        !devices.test(input_device), "inputs must be on unique devices");
     devices.set(input_device);
   }
 }

View File

@@ -11,6 +11,7 @@
 #include <torch/csrc/utils/pybind.h>
 
 #include <c10/cuda/CUDAGuard.h>
+#include <c10/util/Exception.h>
 #include <c10/util/irange.h>
 
 using namespace at;
@@ -63,10 +64,9 @@ static std::vector<std::optional<at::cuda::CUDAStream>> unpack_streams(
     return std::vector<std::optional<at::cuda::CUDAStream>>(size, std::nullopt);
   }
   auto streams = THPUtils_PySequence_to_CUDAStreamList(obj);
-  if (streams.size() != size) {
-    throw std::runtime_error(
-        "number of streams is not equal to number of inputs");
-  }
+  TORCH_CHECK(
+      streams.size() == size,
+      "number of streams is not equal to number of inputs");
   return streams;
 }
@@ -90,10 +90,9 @@ static std::vector<ncclComm_t> unpack_comms(PyObject* obj, size_t size) {
       comms[i] = unpack_nccl_comm(PySequence_Fast_GET_ITEM(seq.get(), i));
     }
   }
-  if (comms.size() != size) {
-    throw std::runtime_error(
-        "number of communicators is not equal to number of inputs");
-  }
+  TORCH_CHECK(
+      comms.size() == size,
+      "number of communicators is not equal to number of inputs");
   return comms;
 }

View File

@@ -1,3 +1,4 @@
+#include <c10/util/Exception.h>
 #include <torch/csrc/Stream.h>
 #include <torch/csrc/cuda/THCP.h>
 #include <torch/csrc/python_headers.h>
@@ -8,18 +9,17 @@
 // whatever the current stream of the device the input is associated with was.
 std::vector<std::optional<at::cuda::CUDAStream>>
 THPUtils_PySequence_to_CUDAStreamList(PyObject* obj) {
-  if (!PySequence_Check(obj)) {
-    throw std::runtime_error(
-        "Expected a sequence in THPUtils_PySequence_to_CUDAStreamList");
-  }
+  TORCH_CHECK(
+      PySequence_Check(obj),
+      "Expected a sequence in THPUtils_PySequence_to_CUDAStreamList");
   THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, nullptr));
-  if (seq.get() == nullptr) {
-    throw std::runtime_error(
-        "expected PySequence, but got " + std::string(THPUtils_typename(obj)));
-  }
+  TORCH_CHECK(
+      seq.get() != nullptr,
+      "expected PySequence, but got " + std::string(THPUtils_typename(obj)));
 
   std::vector<std::optional<at::cuda::CUDAStream>> streams;
   Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
+  streams.reserve(length);
   for (Py_ssize_t i = 0; i < length; i++) {
     PyObject* stream = PySequence_Fast_GET_ITEM(seq.get(), i);
@@ -34,7 +34,8 @@ THPUtils_PySequence_to_CUDAStreamList(PyObject* obj) {
     } else if (stream == Py_None) {
       streams.emplace_back();
     } else {
-      throw std::runtime_error(
+      TORCH_CHECK(
+          false,
           "Unknown data type found in stream list. Need torch.cuda.Stream or None");
     }
   }

View File

@@ -1,3 +1,4 @@
+#include <c10/util/Exception.h>
 #include <torch/csrc/export/upgrader.h>
 #include <limits>
 #include <map>
@@ -23,34 +24,29 @@ static const std::multiset<Upgrader>& getUpgrader(int current_version) {
 }
 
 static nlohmann::json getFieldByKeypath(
-    const nlohmann::json& obj,
+    nlohmann::json obj,
     const std::vector<std::string>& keypath) {
-  nlohmann::json current = obj;
   for (const auto& key : keypath) {
-    if (!current.contains(key)) {
-      throw std::runtime_error("Keypath not found: " + key);
-    }
-    current = current[key];
+    TORCH_CHECK(obj.contains(key), "Keypath not found: " + key);
+    obj = obj[key];
   }
-  return current;
+  return obj;
 }
 
 static void setFieldByKeypath(
     nlohmann::json& obj,
     const std::vector<std::string>& keypath,
-    const nlohmann::json& value) {
+    nlohmann::json value) {
   nlohmann::json* current = &obj;
   for (size_t i = 0; i < keypath.size() - 1; ++i) {
     const auto& key = keypath[i];
-    if (!current->contains(key)) {
-      throw std::runtime_error("Keypath not found: " + key);
-    }
+    TORCH_CHECK(current->contains(key), "Keypath not found: " + key);
     current = &((*current)[key]);
   }
-  if (!current->contains(keypath.back())) {
-    throw std::runtime_error("Keypath not found: " + keypath.back());
-  }
-  (*current)[keypath.back()] = value;
+  TORCH_CHECK(
+      current->contains(keypath.back()),
+      "Keypath not found: " + keypath.back());
+  (*current)[keypath.back()] = std::move(value);
 }
 
 Upgrader::Upgrader(std::vector<std::string> kp, UpgraderFunction func)
@@ -85,7 +81,7 @@ void registerUpgrader(
         error_stream << ".";
       error_stream << keypath[i];
     }
-    throw std::runtime_error(error_stream.str());
+    TORCH_CHECK(false, error_stream.str());
    }
  }
 }
@@ -113,7 +109,7 @@ void registerUpgrader(
     throw std::invalid_argument("Empty keypath provided");
   }
 
-  registerUpgrader(version, keypath_vector, upgrade_func);
+  registerUpgrader(version, std::move(keypath_vector), upgrade_func);
 }
 
 bool deregisterUpgrader(int version, const std::vector<std::string>& keypath) {
@@ -176,18 +172,16 @@ void throwUpgraderError(
     error_stream << "\nProblematic object: " << problematic_object.dump(2);
   }
 
-  throw std::runtime_error(error_stream.str());
+  TORCH_CHECK(false, error_stream.str());
 }
 
-nlohmann::json upgrade(const nlohmann::json& artifact, int target_version) {
-  auto current_artifact = artifact;
-
+nlohmann::json upgrade(nlohmann::json artifact, int target_version) {
   // Validate that the artifact contains required schema version information
-  if (!current_artifact.contains("schema_version")) {
-    throw std::runtime_error("Missing schema_version field in artifact");
-  }
+  TORCH_CHECK(
+      artifact.contains("schema_version"),
+      "Missing schema_version field in artifact");
 
-  int current_version = current_artifact["schema_version"]["major"];
+  int current_version = artifact["schema_version"]["major"];
 
   // Iteratively apply upgraders until target version is reached or no more are
   // available
@@ -204,14 +198,13 @@ nlohmann::json upgrade(const nlohmann::json& artifact, int target_version) {
     // (deeper keypaths first to prevent parent/child conflicts)
     for (const auto& upgrader : upgraders) {
       // Extract the field to be upgraded using its keypath
-      auto field_to_upgrade =
-          getFieldByKeypath(current_artifact, upgrader.keypath);
+      auto field_to_upgrade = getFieldByKeypath(artifact, upgrader.keypath);
 
       // Apply the upgrade transformation
-      auto upgraded_field = upgrader.upgrade_func(field_to_upgrade);
+      auto upgraded_field = upgrader.upgrade_func(std::move(field_to_upgrade));
 
       // Update the artifact with the upgraded field
-      setFieldByKeypath(current_artifact, upgrader.keypath, upgraded_field);
+      setFieldByKeypath(artifact, upgrader.keypath, upgraded_field);
     }
 
     // Move to the next version for potential additional upgrades
@@ -219,11 +212,11 @@ nlohmann::json upgrade(const nlohmann::json& artifact, int target_version) {
   }
 
   // Update schema version to reflect the final upgraded version
-  if (current_artifact["schema_version"]["major"] != current_version) {
-    current_artifact["schema_version"]["major"] = current_version;
+  if (artifact["schema_version"]["major"] != current_version) {
+    artifact["schema_version"]["major"] = current_version;
     // Reset minor version to 0 - the correct minor version should be set
     // when converting the json to in memory representation of ExportedProgram
-    current_artifact["schema_version"]["minor"] = 0;
+    artifact["schema_version"]["minor"] = 0;
   }
 
   // Validate that we reached the target version if requested
@@ -233,10 +226,10 @@ nlohmann::json upgrade(const nlohmann::json& artifact, int target_version) {
         << "Failed to upgrade to target version " << target_version
         << ". Final version reached: " << current_version
         << ". This may indicate missing upgraders for intermediate versions.";
-    throw std::runtime_error(error_stream.str());
+    TORCH_CHECK(false, error_stream.str());
   }
 
-  return current_artifact;
+  return artifact;
 }
 
 } // namespace torch::_export

View File

@@ -108,11 +108,12 @@ void throwUpgraderError(
 /// e.g. adding a new field with default value, it's automatically handled by
 /// the default constructor in generated_serialization_types.h.
 ///
-/// @param artifact The JSON artifact to upgrade
+/// @param artifact The JSON artifact to upgrade (passed by value: the function
+/// operates on a local copy, so the original remains unmodified)
 /// @param target_version The target schema version to upgrade to
 /// @return The upgraded JSON artifact with updated schema version
 /// @throws std::runtime_error if artifact is missing schema_version field
 /// @throws std::runtime_error if final version doesn't match target version
-nlohmann::json upgrade(const nlohmann::json& artifact, int target_version);
+nlohmann::json upgrade(nlohmann::json artifact, int target_version);
 
 } // namespace torch::_export
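
For reference, a minimal caller-side sketch of the new by-value upgrade() signature (the artifact JSON here is a made-up example; real artifacts come from serialized ExportedPrograms):

    #include <torch/csrc/export/upgrader.h>
    #include <nlohmann/json.hpp>

    // Hypothetical artifact already at schema major version 2.
    nlohmann::json artifact = {
        {"schema_version", {{"major", 2}, {"minor", 0}}}};

    // Pass-by-value: upgrade() mutates and returns its own copy, so the
    // caller's artifact is unchanged. With target_version equal to the
    // current major version, no upgraders run and the copy round-trips.
    nlohmann::json upgraded = torch::_export::upgrade(artifact, 2);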