Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 00:20:18 +01:00)
[Code Clean] Replace std::runtime_error with TORCH_CHECK (#163437)

Replace uses of the vanilla C++ exception std::runtime_error with TORCH_CHECK, covering:
- torch/csrc/export
- torch/csrc/cuda

Fixes #148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163437
Approved by: https://github.com/Skylion007, https://github.com/cyyever
This commit is contained in:
parent bb0635d7dd
commit 058814794b
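The change is mechanical throughout the hunks below, which touch files under torch/csrc/cuda and torch/csrc/export: an "if (bad_condition) { throw std::runtime_error(...); }" guard collapses into a single TORCH_CHECK(required_condition, ...), and an unconditional throw becomes TORCH_CHECK(false, ...). A minimal before/after sketch of the pattern (the helper and its arguments are illustrative, not taken from the diff):

    #include <c10/util/Exception.h>
    #include <cstdint>

    // Hypothetical validation helper, for illustration only.
    void check_positive(int64_t n) {
      // Before: vanilla C++ exception.
      //   if (n <= 0) {
      //     throw std::runtime_error("n must be positive");
      //   }
      // After: assert the condition that must hold; on failure TORCH_CHECK
      // throws a c10::Error carrying file/line/function context.
      TORCH_CHECK(n > 0, "n must be positive, got ", n);
    }

The extra message arguments after the condition are only stringified when the check actually fails.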
@@ -4,6 +4,7 @@
 #include <ATen/native/ConvUtils.h>
 #include <c10/core/Device.h>
 #include <c10/core/TensorImpl.h>
+#include <c10/util/Exception.h>
 #include <c10/util/UniqueVoidPtr.h>
 #include <pybind11/pytypes.h>
 #include <torch/csrc/utils/python_arg_parser.h>
@@ -861,7 +862,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) {
       case TraceEntry::SEGMENT_MAP:
         return segment_map_s;
     }
-    throw std::runtime_error("unreachable");
+    TORCH_CHECK(false, "unreachable");
   };

   for (const auto& traceInfo : snapshot.device_traces) {
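In this hunk, and in the matching one in the pickled-snapshot path below, the throw sits after a switch that already returns for every TraceEntry action, so the replacement is an unconditional TORCH_CHECK(false, ...). A small sketch of that idiom with a made-up enum:

    #include <c10/util/Exception.h>

    enum class Kind { A, B };  // illustrative enum, not from the diff

    const char* kind_name(Kind k) {
      switch (k) {
        case Kind::A:
          return "A";
        case Kind::B:
          return "B";
      }
      // Every enumerator returned above; this line documents that the point is
      // unreachable and fails loudly (with a c10::Error) if it is ever hit.
      TORCH_CHECK(false, "unreachable");
    }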
@@ -1,6 +1,7 @@
 #include <ATen/Context.h>
 #include <ATen/record_function.h>
 #include <c10/cuda/CUDACachingAllocator.h>
+#include <c10/util/Exception.h>
 #include <torch/csrc/cuda/memory_snapshot.h>
 #include <torch/csrc/jit/runtime/interpreter.h>
 #include <torch/csrc/jit/serialization/pickler.h>
@@ -413,7 +414,7 @@ std::string _memory_snapshot_pickled() {
       case TraceEntry::SEGMENT_MAP:
         return segment_map_s;
     }
-    throw std::runtime_error("unreachable");
+    TORCH_CHECK(false, "unreachable");
   };

   for (const auto& traceInfo : snapshot.device_traces) {
@@ -62,7 +62,7 @@ ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
     case torch::cuda::nccl::ncclResult::NumResults:
       return ncclResult_t::ncclNumResults;
     default:
-      throw std::runtime_error("Unconvertible NCCL type");
+      TORCH_CHECK(false, "Unconvertible NCCL type");
   }
 }

@@ -91,7 +91,7 @@ torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
     case ncclNumResults:
       return torch::cuda::nccl::ncclResult::NumResults;
     default:
-      throw std::runtime_error("Unconvertible NCCL type");
+      TORCH_CHECK(false, "Unconvertible NCCL type");
   }
 }

@@ -194,10 +194,9 @@ static void NCCL_CHECK_TIMEOUT(ncclResult status, ncclComm_t comm) {
     auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
                            currentTimepoint - startTimepoint)
                            .count();
-    if (timeElapsed > nccl_nonblocking_timeout()) {
-      throw std::runtime_error(
-          "NCCL timeout when waiting for nonblocking call to become successful.");
-    }
+    TORCH_CHECK(
+        timeElapsed <= nccl_nonblocking_timeout(),
+        "NCCL timeout when waiting for nonblocking call to become successful.");
     sched_yield(); // yield to other threads
     ncclCommGetAsyncError(to_nccl_comm(comm), &result);
   }
@@ -227,10 +226,9 @@ static void NCCL_CHECK_TIMEOUT(
     auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(
                            currentTimepoint - startTimepoint)
                            .count();
-    if (timeElapsed > nccl_nonblocking_timeout()) {
-      throw std::runtime_error(
-          "NCCL timeout when waiting for nonblocking call to become successful.");
-    }
+    TORCH_CHECK(
+        timeElapsed <= nccl_nonblocking_timeout(),
+        "NCCL timeout when waiting for nonblocking call to become successful.");
     sched_yield(); // yield to other threads
     ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result);
   } while (result == ncclInProgress);
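When an if (...) throw guard is folded into TORCH_CHECK, the predicate flips: the original condition named the failure case, while TORCH_CHECK takes the condition that must hold, as in the timeout checks above. A minimal sketch of the inversion with illustrative names:

    #include <c10/util/Exception.h>

    // Illustrative timeout guard, not the real NCCL helper.
    void guard_timeout(long elapsed_s, long limit_s) {
      // Before: the condition describes the bad case.
      //   if (elapsed_s > limit_s) {
      //     throw std::runtime_error("timed out");
      //   }
      // After: the condition describes the good case; the check fires on its
      // negation.
      TORCH_CHECK(elapsed_s <= limit_s, "timed out");
    }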
@@ -258,7 +256,7 @@ void throw_nccl_error(torch::cuda::nccl::ncclResult status) {
   std::ostringstream err;
   err << "NCCL Error " << static_cast<int>(status) << ": "
       << ncclGetErrorString(to_nccl_result(status));
-  throw std::runtime_error(err.str());
+  TORCH_CHECK(false, err.str());
 }

 struct NcclCommList {
@@ -318,41 +316,36 @@ static void check_tensor(
     int64_t ref_numel,
     ScalarType ref_dtype) {
   auto check_one = [&](const at::Tensor& tensor) {
-    if (!tensor.is_cuda() || tensor.is_sparse()) {
-      throw std::runtime_error(
-          "input and output elements have to be cuda dense Tensors");
-    }
+    TORCH_CHECK(
+        tensor.is_cuda() && !tensor.is_sparse(),
+        "input and output elements have to be cuda dense Tensors");

-    if (ref_dtype != tensor.scalar_type()) {
-      throw std::runtime_error(
-          "all inputs and outputs must be of the same Tensor dtype");
-    }
+    TORCH_CHECK(
+        ref_dtype == tensor.scalar_type(),
+        "all inputs and outputs must be of the same Tensor dtype");

-    if (!tensor.is_contiguous()) {
-      throw std::runtime_error("all inputs and outputs have to be contiguous");
-    }
+    TORCH_CHECK(
+        tensor.is_contiguous(), "all inputs and outputs have to be contiguous");
   };

   check_one(input);

   // all inputs must be same size
-  if (input.numel() != ref_numel) {
-    throw std::runtime_error(
-        "all inputs must have the same number of elements");
-  }
+  TORCH_CHECK(
+      input.numel() == ref_numel,
+      "all inputs must have the same number of elements");

   if (output) {
     check_one(*output);

     // inputs and outputs must be on same device respectively
-    if (input.get_device() != output->get_device()) {
-      throw std::runtime_error("input and output must be on the same device");
-    }
+    TORCH_CHECK(
+        input.get_device() == output->get_device(),
+        "input and output must be on the same device");

-    if (output->numel() * output_multiplier != ref_numel * input_multiplier) {
-      throw std::runtime_error(
-          "output must be of size input_size * size_multiplier");
-    }
+    TORCH_CHECK(
+        output->numel() * output_multiplier == ref_numel * input_multiplier,
+        "output must be of size input_size * size_multiplier");
   }
 }

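check_tensor routes every per-tensor requirement through a small check_one lambda so the same TORCH_CHECKs run for the input and, when present, the output. A standalone sketch of that shape, reusing the checks from the hunk (the wrapper function itself is illustrative):

    #include <ATen/ATen.h>
    #include <c10/util/Exception.h>

    // Illustrative wrapper: one validation lambda applied to several tensors.
    void check_all_dense_cuda(const at::Tensor& a, const at::Tensor& b) {
      auto check_one = [&](const at::Tensor& tensor) {
        TORCH_CHECK(
            tensor.is_cuda() && !tensor.is_sparse(),
            "input and output elements have to be cuda dense Tensors");
        TORCH_CHECK(
            tensor.is_contiguous(), "all inputs and outputs have to be contiguous");
      };
      check_one(a);
      check_one(b);
    }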
@@ -364,15 +357,13 @@ void check_inputs(
   // len(inputs) == len(outputs)
   size_t len = inputs.size();

-  if (len == 0) {
-    throw std::runtime_error("input sequence can't be empty");
-  }
+  TORCH_CHECK(len != 0, "input sequence can't be empty");

   if (len != outputs.size()) {
     std::stringstream err;
     err << "inputs and outputs sequences have to be of the same length, but got input of length "
         << len << " and output of length " << outputs.size();
-    throw std::runtime_error(err.str());
+    TORCH_CHECK(false, err.str());
   }

   device_set devices;
@@ -388,9 +379,8 @@ void check_inputs(

     auto input_device = input.get_device();
     // inputs must be on unique devices
-    if (devices.test(input_device)) {
-      throw std::runtime_error("inputs must be on unique devices");
-    }
+    TORCH_CHECK(
+        !devices.test(input_device), "inputs must be on unique devices");
     devices.set(input_device);
   }
 }
@@ -403,9 +393,7 @@ void check_inputs(
     int output_multiplier) {
   auto len = inputs.size();

-  if (len <= 0) {
-    throw std::runtime_error("input sequence can't be empty");
-  }
+  TORCH_CHECK(len > 0, "input sequence can't be empty");

   device_set devices;
   int64_t numel = inputs[0].numel();
@@ -426,9 +414,8 @@ void check_inputs(

     auto input_device = input.get_device();
     // inputs must be on unique devices
-    if (devices.test(input_device)) {
-      throw std::runtime_error("inputs must be on unique devices");
-    }
+    TORCH_CHECK(
+        !devices.test(input_device), "inputs must be on unique devices");
     devices.set(input_device);
   }
 }
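Several of the surviving checks (for example the length-mismatch branch above) still build their message through a std::stringstream and hand err.str() to TORCH_CHECK as a single argument. TORCH_CHECK also accepts several message arguments directly, stringifying them only on failure; a hedged sketch of that alternative with illustrative variables:

    #include <c10/util/Exception.h>
    #include <cstddef>

    // Illustrative length check, not the exact code from the diff.
    void check_same_length(std::size_t inputs_len, std::size_t outputs_len) {
      TORCH_CHECK(
          inputs_len == outputs_len,
          "inputs and outputs sequences have to be of the same length, but got input of length ",
          inputs_len,
          " and output of length ",
          outputs_len);
    }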
@@ -11,6 +11,7 @@
 #include <torch/csrc/utils/pybind.h>

 #include <c10/cuda/CUDAGuard.h>
+#include <c10/util/Exception.h>
 #include <c10/util/irange.h>

 using namespace at;
@@ -63,10 +64,9 @@ static std::vector<std::optional<at::cuda::CUDAStream>> unpack_streams(
     return std::vector<std::optional<at::cuda::CUDAStream>>(size, std::nullopt);
   }
   auto streams = THPUtils_PySequence_to_CUDAStreamList(obj);
-  if (streams.size() != size) {
-    throw std::runtime_error(
-        "number of streams is not equal to number of inputs");
-  }
+  TORCH_CHECK(
+      streams.size() == size,
+      "number of streams is not equal to number of inputs");
   return streams;
 }

@@ -90,10 +90,9 @@ static std::vector<ncclComm_t> unpack_comms(PyObject* obj, size_t size) {
       comms[i] = unpack_nccl_comm(PySequence_Fast_GET_ITEM(seq.get(), i));
     }
   }
-  if (comms.size() != size) {
-    throw std::runtime_error(
-        "number of communicators is not equal to number of inputs");
-  }
+  TORCH_CHECK(
+      comms.size() == size,
+      "number of communicators is not equal to number of inputs");
   return comms;
 }

@@ -1,3 +1,4 @@
+#include <c10/util/Exception.h>
 #include <torch/csrc/Stream.h>
 #include <torch/csrc/cuda/THCP.h>
 #include <torch/csrc/python_headers.h>
@@ -8,18 +9,17 @@
 // whatever the current stream of the device the input is associated with was.
 std::vector<std::optional<at::cuda::CUDAStream>>
 THPUtils_PySequence_to_CUDAStreamList(PyObject* obj) {
-  if (!PySequence_Check(obj)) {
-    throw std::runtime_error(
-        "Expected a sequence in THPUtils_PySequence_to_CUDAStreamList");
-  }
+  TORCH_CHECK(
+      PySequence_Check(obj),
+      "Expected a sequence in THPUtils_PySequence_to_CUDAStreamList");
   THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, nullptr));
-  if (seq.get() == nullptr) {
-    throw std::runtime_error(
-        "expected PySequence, but got " + std::string(THPUtils_typename(obj)));
-  }
+  TORCH_CHECK(
+      seq.get() != nullptr,
+      "expected PySequence, but got " + std::string(THPUtils_typename(obj)));

   std::vector<std::optional<at::cuda::CUDAStream>> streams;
   Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
   streams.reserve(length);
   for (Py_ssize_t i = 0; i < length; i++) {
     PyObject* stream = PySequence_Fast_GET_ITEM(seq.get(), i);

@@ -34,7 +34,8 @@ THPUtils_PySequence_to_CUDAStreamList(PyObject* obj) {
     } else if (stream == Py_None) {
       streams.emplace_back();
     } else {
-      throw std::runtime_error(
+      TORCH_CHECK(
+          false,
           "Unknown data type found in stream list. Need torch.cuda.Stream or None");
     }
   }
@@ -1,3 +1,4 @@
+#include <c10/util/Exception.h>
 #include <torch/csrc/export/upgrader.h>
 #include <limits>
 #include <map>
@@ -23,34 +24,29 @@ static const std::multiset<Upgrader>& getUpgrader(int current_version) {
 }

 static nlohmann::json getFieldByKeypath(
-    const nlohmann::json& obj,
+    nlohmann::json obj,
     const std::vector<std::string>& keypath) {
-  nlohmann::json current = obj;
   for (const auto& key : keypath) {
-    if (!current.contains(key)) {
-      throw std::runtime_error("Keypath not found: " + key);
-    }
-    current = current[key];
+    TORCH_CHECK(obj.contains(key), "Keypath not found: " + key);
+    obj = obj[key];
   }
-  return current;
+  return obj;
 }

 static void setFieldByKeypath(
     nlohmann::json& obj,
     const std::vector<std::string>& keypath,
-    const nlohmann::json& value) {
+    nlohmann::json value) {
   nlohmann::json* current = &obj;
   for (size_t i = 0; i < keypath.size() - 1; ++i) {
     const auto& key = keypath[i];
-    if (!current->contains(key)) {
-      throw std::runtime_error("Keypath not found: " + key);
-    }
+    TORCH_CHECK(current->contains(key), "Keypath not found: " + key);
     current = &((*current)[key]);
   }
-  if (!current->contains(keypath.back())) {
-    throw std::runtime_error("Keypath not found: " + keypath.back());
-  }
-  (*current)[keypath.back()] = value;
+  TORCH_CHECK(
+      current->contains(keypath.back()),
+      "Keypath not found: " + keypath.back());
+  (*current)[keypath.back()] = std::move(value);
 }

 Upgrader::Upgrader(std::vector<std::string> kp, UpgraderFunction func)
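getFieldByKeypath walks the JSON document one key at a time and fails on the first missing key; setFieldByKeypath does the same but stops one level short so it can assign into the parent object. A small standalone usage sketch of the same traversal (the document contents here are made up):

    #include <c10/util/Exception.h>
    #include <nlohmann/json.hpp>
    #include <string>
    #include <vector>

    int main() {
      nlohmann::json doc = {{"schema_version", {{"major", 8}, {"minor", 0}}}};
      std::vector<std::string> keypath = {"schema_version", "major"};

      // The same walk getFieldByKeypath performs, failing on a missing key.
      nlohmann::json cur = doc;
      for (const auto& key : keypath) {
        TORCH_CHECK(cur.contains(key), "Keypath not found: " + key);
        cur = cur[key];
      }
      return cur.get<int>() == 8 ? 0 : 1;  // cur now holds 8
    }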
@@ -85,7 +81,7 @@ void registerUpgrader(
       error_stream << ".";
       error_stream << keypath[i];
     }
-    throw std::runtime_error(error_stream.str());
+    TORCH_CHECK(false, error_stream.str());
   }
 }
 }
@@ -113,7 +109,7 @@ void registerUpgrader(
     throw std::invalid_argument("Empty keypath provided");
   }

-  registerUpgrader(version, keypath_vector, upgrade_func);
+  registerUpgrader(version, std::move(keypath_vector), upgrade_func);
 }

 bool deregisterUpgrader(int version, const std::vector<std::string>& keypath) {
@@ -176,18 +172,16 @@ void throwUpgraderError(
     error_stream << "\nProblematic object: " << problematic_object.dump(2);
   }

-  throw std::runtime_error(error_stream.str());
+  TORCH_CHECK(false, error_stream.str());
 }

-nlohmann::json upgrade(const nlohmann::json& artifact, int target_version) {
-  auto current_artifact = artifact;
-
+nlohmann::json upgrade(nlohmann::json artifact, int target_version) {
   // Validate that the artifact contains required schema version information
-  if (!current_artifact.contains("schema_version")) {
-    throw std::runtime_error("Missing schema_version field in artifact");
-  }
+  TORCH_CHECK(
+      artifact.contains("schema_version"),
+      "Missing schema_version field in artifact");

-  int current_version = current_artifact["schema_version"]["major"];
+  int current_version = artifact["schema_version"]["major"];

   // Iteratively apply upgraders until target version is reached or no more are
   // available
@@ -204,14 +198,13 @@ nlohmann::json upgrade(const nlohmann::json& artifact, int target_version) {
     // (deeper keypaths first to prevent parent/child conflicts)
     for (const auto& upgrader : upgraders) {
       // Extract the field to be upgraded using its keypath
-      auto field_to_upgrade =
-          getFieldByKeypath(current_artifact, upgrader.keypath);
+      auto field_to_upgrade = getFieldByKeypath(artifact, upgrader.keypath);

       // Apply the upgrade transformation
-      auto upgraded_field = upgrader.upgrade_func(field_to_upgrade);
+      auto upgraded_field = upgrader.upgrade_func(std::move(field_to_upgrade));

       // Update the artifact with the upgraded field
-      setFieldByKeypath(current_artifact, upgrader.keypath, upgraded_field);
+      setFieldByKeypath(artifact, upgrader.keypath, upgraded_field);
     }

     // Move to the next version for potential additional upgrades
@@ -219,11 +212,11 @@ nlohmann::json upgrade(const nlohmann::json& artifact, int target_version) {
   }

   // Update schema version to reflect the final upgraded version
-  if (current_artifact["schema_version"]["major"] != current_version) {
-    current_artifact["schema_version"]["major"] = current_version;
+  if (artifact["schema_version"]["major"] != current_version) {
+    artifact["schema_version"]["major"] = current_version;
     // Reset minor version to 0 - the correct minor version should be set
     // when converting the json to in memory representation of ExportedProgram
-    current_artifact["schema_version"]["minor"] = 0;
+    artifact["schema_version"]["minor"] = 0;
   }

   // Validate that we reached the target version if requested
@@ -233,10 +226,10 @@ nlohmann::json upgrade(const nlohmann::json& artifact, int target_version) {
         << "Failed to upgrade to target version " << target_version
         << ". Final version reached: " << current_version
         << ". This may indicate missing upgraders for intermediate versions.";
-    throw std::runtime_error(error_stream.str());
+    TORCH_CHECK(false, error_stream.str());
   }

-  return current_artifact;
+  return artifact;
 }

 } // namespace torch::_export
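Beyond the exception swap, upgrade() changes from taking const nlohmann::json& and copying it into a local current_artifact to taking the artifact by value and mutating that copy, with std::move used where ownership can be handed off. A minimal illustration of why a by-value signature plus std::move avoids an extra copy at the call site (the bump_major helper is illustrative, not part of the diff):

    #include <nlohmann/json.hpp>
    #include <utility>

    // Illustrative sink mirroring the new upgrade() signature: it owns its own
    // copy of the document and may mutate it freely before returning it.
    nlohmann::json bump_major(nlohmann::json doc) {
      doc["schema_version"]["major"] = 9;
      return doc;
    }

    int main() {
      nlohmann::json artifact = {{"schema_version", {{"major", 8}, {"minor", 0}}}};

      // Caller that still needs artifact afterwards: one copy is made for the call.
      nlohmann::json upgraded = bump_major(artifact);

      // Caller that is done with artifact: moving in avoids that copy entirely.
      nlohmann::json upgraded_moved = bump_major(std::move(artifact));
      return (upgraded == upgraded_moved) ? 0 : 1;
    }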
@@ -108,11 +108,12 @@ void throwUpgraderError(
 /// e.g. adding a new field with default value, it's automatically handled by
 /// the default constructor in generated_serialization_types.h.
 ///
-/// @param artifact The JSON artifact to upgrade
+/// @param artifact The JSON artifact to upgrade(passed by value: function
+/// operates on a local copy, original remains unmodified)
 /// @param target_version The target schema version to upgrade to
 /// @return The upgraded JSON artifact with updated schema version
 /// @throws std::runtime_error if artifact is missing schema_version field
 /// @throws std::runtime_error if final version doesn't match target version
-nlohmann::json upgrade(const nlohmann::json& artifact, int target_version);
+nlohmann::json upgrade(nlohmann::json artifact, int target_version);

 } // namespace torch::_export
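One observable consequence of the swap is the exception type: a failing TORCH_CHECK raises c10::Error (a std::exception subclass) rather than std::runtime_error, so callers that catch by type may want to catch c10::Error or std::exception. A hedged sketch of calling the upgrader API declared above and handling a failure:

    #include <c10/util/Exception.h>
    #include <nlohmann/json.hpp>
    #include <torch/csrc/export/upgrader.h>
    #include <iostream>

    int main() {
      nlohmann::json artifact;  // deliberately missing "schema_version"
      try {
        torch::_export::upgrade(artifact, /*target_version=*/2);
      } catch (const c10::Error& e) {
        // TORCH_CHECK failures surface here, not as std::runtime_error.
        std::cerr << e.what() << '\n';
      }
      return 0;
    }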