mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[CodeClean] Replace std::runtime_error with TORCH_CHECK (#164130)
As the title stated. **Changes**: - torch/csrc/inductor(Part 1) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164130 Approved by: https://github.com/albanD, https://github.com/Skylion007
This commit is contained in:
parent
09a4187b8e
commit
6f713e25bb
|
|
@ -102,11 +102,10 @@ std::string create_temp_dir() {
|
|||
}
|
||||
#else
|
||||
std::string temp_dir = "/tmp/XXXXXX";
|
||||
if (mkdtemp(temp_dir.data()) == nullptr) {
|
||||
throw std::runtime_error(
|
||||
std::string("Failed to create temporary directory: ") +
|
||||
c10::utils::str_error(errno));
|
||||
}
|
||||
TORCH_CHECK(
|
||||
mkdtemp(temp_dir.data()) != nullptr,
|
||||
"Failed to create temporary directory: ",
|
||||
c10::utils::str_error(errno));
|
||||
return temp_dir;
|
||||
#endif
|
||||
}
|
||||
|
|
@ -156,9 +155,7 @@ namespace torch::inductor {
|
|||
|
||||
namespace {
|
||||
const nlohmann::json& load_json_file(const std::string& json_path) {
|
||||
if (!file_exists(json_path)) {
|
||||
throw std::runtime_error("File not found: " + json_path);
|
||||
}
|
||||
TORCH_CHECK(file_exists(json_path), "File not found: ", json_path);
|
||||
|
||||
std::ifstream json_file(json_path);
|
||||
TORCH_CHECK(json_file.is_open());
|
||||
|
|
@ -415,32 +412,25 @@ std::string compile_so(
|
|||
get_cpp_compile_command(filename, obj_filenames, linker_flags);
|
||||
|
||||
// Run the commands to generate a .so file
|
||||
int status = system(compile_cmd.c_str());
|
||||
if (status != 0) {
|
||||
throw std::runtime_error("Failed to compile cpp file.");
|
||||
}
|
||||
status = system(link_cmd.c_str());
|
||||
if (status != 0) {
|
||||
throw std::runtime_error("Failed to link files.");
|
||||
}
|
||||
TORCH_CHECK(system(compile_cmd.c_str()) == 0, "Failed to compile cpp file.");
|
||||
TORCH_CHECK(system(link_cmd.c_str()) == 0, "Failed to link files.");
|
||||
|
||||
// Move the mmapped weights onto the .so
|
||||
std::string serialized_weights_path = filename + "_serialized_weights.bin";
|
||||
if (file_exists(serialized_weights_path)) {
|
||||
std::ifstream serialized_weights_file(
|
||||
serialized_weights_path, std::ios::binary);
|
||||
if (!serialized_weights_file.is_open()) {
|
||||
throw std::runtime_error("Failed to open serialized weights file");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
serialized_weights_file.is_open(),
|
||||
"Failed to open serialized weights file");
|
||||
|
||||
std::vector<char> serialized_weights(
|
||||
(std::istreambuf_iterator<char>(serialized_weights_file)),
|
||||
std::istreambuf_iterator<char>());
|
||||
serialized_weights_file.close();
|
||||
|
||||
std::ofstream output_so_file(output_so, std::ios::binary | std::ios::app);
|
||||
if (!output_so_file.is_open()) {
|
||||
throw std::runtime_error("Failed to open output .so file");
|
||||
}
|
||||
TORCH_CHECK(output_so_file.is_open(), "Failed to open output .so file");
|
||||
// Page align the weights
|
||||
std::streampos so_size = output_so_file.tellp();
|
||||
std::vector<char> padding(16384 - so_size % 16384, ' ');
|
||||
|
|
@ -495,12 +485,11 @@ class RAIIMinizArchive {
|
|||
public:
|
||||
RAIIMinizArchive(const std::string& zip_path) {
|
||||
mz_zip_zero_struct(&_zip_archive);
|
||||
if (!mz_zip_reader_init_file(
|
||||
&_zip_archive, normalize_path_separator(zip_path).c_str(), 0)) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Failed to initialize zip archive: {}",
|
||||
mz_zip_get_error_string(mz_zip_get_last_error(&_zip_archive))));
|
||||
}
|
||||
TORCH_CHECK(
|
||||
mz_zip_reader_init_file(
|
||||
&_zip_archive, normalize_path_separator(zip_path).c_str(), 0),
|
||||
"Failed to initialize zip archive: ",
|
||||
mz_zip_get_error_string(mz_zip_get_last_error(&_zip_archive)));
|
||||
}
|
||||
RAIIMinizArchive(const RAIIMinizArchive&) = delete;
|
||||
RAIIMinizArchive& operator=(const RAIIMinizArchive&) = delete;
|
||||
|
|
@ -522,18 +511,18 @@ class RAIIMinizArchive {
|
|||
// terminator
|
||||
const auto zip_filename_len{
|
||||
mz_zip_reader_get_filename(&_zip_archive, i, nullptr, 0)};
|
||||
if (!zip_filename_len) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("Failed to read zip filename length at index {}", i));
|
||||
}
|
||||
TORCH_CHECK(
|
||||
zip_filename_len, "Failed to read zip filename length at index ", i);
|
||||
|
||||
// std::string implicitly appends a character for the null terminator
|
||||
std::string zip_filename(zip_filename_len - 1, '\0');
|
||||
if (!mz_zip_reader_get_filename(
|
||||
&_zip_archive, i, zip_filename.data(), zip_filename_len)) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("Failed to read zip filename at index {}", i));
|
||||
}
|
||||
zip_filenames.emplace_back(zip_filename);
|
||||
TORCH_CHECK(
|
||||
mz_zip_reader_get_filename(
|
||||
&_zip_archive, i, zip_filename.data(), zip_filename_len),
|
||||
"Failed to read zip filename at index ",
|
||||
i);
|
||||
|
||||
zip_filenames.emplace_back(std::move(zip_filename));
|
||||
}
|
||||
|
||||
return zip_filenames;
|
||||
|
|
@ -551,18 +540,25 @@ class RAIIMinizArchive {
|
|||
0)) {
|
||||
#ifdef _WIN32
|
||||
DWORD dwErrCode = GetLastError();
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Failed to extract zip file {} to destination file {}, error code: {}, mz_zip error string: {}",
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Failed to extract zip file ",
|
||||
zip_filename,
|
||||
" to destination file ",
|
||||
path_dest_filename,
|
||||
", error code: ",
|
||||
dwErrCode,
|
||||
mz_zip_get_error_string(mz_zip_get_last_error(&_zip_archive))));
|
||||
" mz_zip error string: ",
|
||||
mz_zip_get_error_string(mz_zip_get_last_error(&_zip_archive)));
|
||||
#else
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Failed to extract zip file {} to destination file {}, mz_zip error string: {}",
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Failed to extract zip file ",
|
||||
zip_filename,
|
||||
" to destination file ",
|
||||
path_dest_filename,
|
||||
mz_zip_get_error_string(mz_zip_get_last_error(&_zip_archive))));
|
||||
", mz_zip error string: ",
|
||||
mz_zip_get_error_string(mz_zip_get_last_error(&_zip_archive)));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
|
@ -578,9 +574,7 @@ std::unordered_map<std::string, std::string> AOTIModelPackageLoader::
|
|||
// Open the zip archive
|
||||
RAIIMinizArchive zip_archive{model_package_path};
|
||||
auto found_filenames{zip_archive.get_filenames()};
|
||||
if (found_filenames.empty()) {
|
||||
throw std::runtime_error("No files found in zip archive.");
|
||||
}
|
||||
TORCH_CHECK(!found_filenames.empty(), "No files found in zip archive.");
|
||||
|
||||
// Find the file prefix (similar to constructor logic)
|
||||
std::string file_prefix;
|
||||
|
|
@ -624,15 +618,13 @@ std::unordered_map<std::string, std::string> AOTIModelPackageLoader::
|
|||
model_names_str += model_name_tmp + "\n";
|
||||
}
|
||||
|
||||
throw std::runtime_error(
|
||||
"Failed to find a generated cpp file or so file for model '" +
|
||||
model_name +
|
||||
"' in the zip archive.\n\n"
|
||||
"Available models in the archive:\n" +
|
||||
model_names_str +
|
||||
"\n\n"
|
||||
"To load a specific model, please provide its name using the `model_name` parameter when calling AOTIModelPackageLoader() or torch._inductor.package.load_package.\n\n"
|
||||
"The following files were loaded from the archive:\n" +
|
||||
TORCH_CHECK(
|
||||
"Failed to find a generated cpp file or so file for model '",
|
||||
model_name,
|
||||
"' in the zip archive.\n\nAvailable models in the archive:\n",
|
||||
model_names_str,
|
||||
"\n\nTo load a specific model, please provide its name using the `model_name` parameter when calling AOTIModelPackageLoader() or torch._inductor.package.load_package.\n\n",
|
||||
"The following files were loaded from the archive:\n",
|
||||
found_filenames_str);
|
||||
}
|
||||
|
||||
|
|
@ -643,17 +635,15 @@ std::unordered_map<std::string, std::string> AOTIModelPackageLoader::
|
|||
|
||||
// Create the parent directory if it doesn't exist
|
||||
size_t parent_path_idx = output_path_str.find_last_of(k_separator);
|
||||
if (parent_path_idx == std::string::npos) {
|
||||
throw std::runtime_error(
|
||||
"Failed to find parent path in " + output_path_str);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
parent_path_idx != std::string::npos,
|
||||
"Failed to find parent path in " + output_path_str);
|
||||
std::string parent_path = output_path_str.substr(0, parent_path_idx);
|
||||
if (!recursive_mkdir(parent_path)) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Failed to create directory {}: {}",
|
||||
parent_path,
|
||||
c10::utils::str_error(errno)));
|
||||
}
|
||||
TORCH_CHECK(
|
||||
recursive_mkdir(parent_path),
|
||||
"Failed to create directory " + parent_path,
|
||||
": ",
|
||||
c10::utils::str_error(errno));
|
||||
|
||||
LOG(INFO) << "Extract file: " << metadata_filename << " to "
|
||||
<< output_path_str;
|
||||
|
|
@ -679,23 +669,19 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
|||
const size_t num_runners,
|
||||
const c10::DeviceIndex device_index) {
|
||||
if (run_single_threaded) {
|
||||
if (num_runners != 1) {
|
||||
throw std::runtime_error(
|
||||
"num_runners must be 1 when run_single_threaded is true");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
num_runners == 1,
|
||||
"num_runners must be 1 when run_single_threaded is true");
|
||||
} else {
|
||||
if (num_runners < 1) {
|
||||
throw std::runtime_error(
|
||||
"num_runners must be >=1 when run_single_threaded is false");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
num_runners >= 1,
|
||||
"num_runners must be >=1 when run_single_threaded is false");
|
||||
}
|
||||
|
||||
// Extract all files within the zipfile to a temporary directory
|
||||
RAIIMinizArchive zip_archive{model_package_path};
|
||||
auto found_filenames{zip_archive.get_filenames()};
|
||||
if (found_filenames.empty()) {
|
||||
throw std::runtime_error("No files found in zip archive.");
|
||||
}
|
||||
TORCH_CHECK(!found_filenames.empty(), "No files found in zip archive.");
|
||||
|
||||
// All the paths are prepended with a tmp/ directory. We need to find the
|
||||
// prefix.
|
||||
|
|
@ -758,17 +744,16 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
|||
|
||||
// Create the parent directory if it doesn't exist
|
||||
size_t parent_path_idx = output_file_path.find_last_of(k_separator);
|
||||
if (parent_path_idx == std::string::npos) {
|
||||
throw std::runtime_error(
|
||||
"Failed to find parent path in " + output_file_path);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
parent_path_idx != std::string::npos,
|
||||
"Failed to find parent path in " + output_file_path);
|
||||
|
||||
std::string parent_path = output_file_path.substr(0, parent_path_idx);
|
||||
if (!recursive_mkdir(parent_path)) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Failed to create directory {}: {}",
|
||||
parent_path,
|
||||
c10::utils::str_error(errno)));
|
||||
}
|
||||
TORCH_CHECK(
|
||||
recursive_mkdir(parent_path),
|
||||
"Failed to create directory " + parent_path,
|
||||
": ",
|
||||
c10::utils::str_error(errno));
|
||||
|
||||
// Extracts file to the temp directory
|
||||
zip_archive.extract_file(zip_filename_str, output_path_str);
|
||||
|
|
@ -801,15 +786,14 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
|||
model_names_str += model_name_tmp + "\n";
|
||||
}
|
||||
|
||||
throw std::runtime_error(
|
||||
"Failed to find a generated cpp file or so file for model '" +
|
||||
model_name +
|
||||
"' in the zip archive.\n\n"
|
||||
"Available models in the archive:\n" +
|
||||
model_names_str +
|
||||
"\n\n"
|
||||
"To load a specific model, please provide its name using the `model_name` parameter when calling AOTIModelPackageLoader() or torch._inductor.package.load_package.\n\n"
|
||||
"The following files were loaded from the archive:\n" +
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Failed to find a generated cpp file or so file for model '",
|
||||
model_name,
|
||||
"' in the zip archive.\n\nAvailable models in the archive:\n",
|
||||
model_names_str,
|
||||
"\n\nTo load a specific model, please provide its name using the `model_name` parameter when calling AOTIModelPackageLoader() or torch._inductor.package.load_package.\n\n",
|
||||
"The following files were loaded from the archive:\n",
|
||||
found_filenames_str);
|
||||
}
|
||||
|
||||
|
|
@ -823,17 +807,15 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
|||
|
||||
// Construct the runner depending on the device information
|
||||
std::string device_key = metadata_["AOTI_DEVICE_KEY"];
|
||||
|
||||
if (device_key.empty()) {
|
||||
throw std::runtime_error("No device information found.");
|
||||
}
|
||||
TORCH_CHECK(!device_key.empty(), "No device information found.");
|
||||
|
||||
std::unordered_map<std::string, CreateAOTIModelRunnerFunc>
|
||||
registered_aoti_runner = getAOTIModelRunnerRegistry();
|
||||
|
||||
if (registered_aoti_runner.find(device_key) == registered_aoti_runner.end()) {
|
||||
throw std::runtime_error("Unsupported device key found: " + device_key);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
registered_aoti_runner.find(device_key) != registered_aoti_runner.end(),
|
||||
"Unsupported device key found: ",
|
||||
device_key);
|
||||
|
||||
c10::Device device = c10::Device(device_key);
|
||||
device.set_index(device_index);
|
||||
|
|
@ -896,7 +878,7 @@ void AOTIModelPackageLoader::load_constants(
|
|||
if (fqn_to_constant_name.find(it.first) != fqn_to_constant_name.end()) {
|
||||
updated_constants_map.emplace(fqn_to_constant_name[it.first], it.second);
|
||||
} else {
|
||||
throw std::runtime_error("Constant not found: " + it.first);
|
||||
TORCH_CHECK(false, "Constant not found: ", it.first);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -29,15 +29,13 @@ AOTIModelContainerRunner::AOTIModelContainerRunner(
|
|||
const std::string& cubin_dir,
|
||||
const bool run_single_threaded) {
|
||||
if (run_single_threaded) {
|
||||
if (num_models != 1) {
|
||||
throw std::runtime_error(
|
||||
"num_models must be 1 when run_single_threaded is true");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
num_models == 1,
|
||||
"num_models must be 1 when run_single_threaded is true");
|
||||
} else {
|
||||
if (num_models < 1) {
|
||||
throw std::runtime_error(
|
||||
"num_models must be >=1 when run_single_threaded is false");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
num_models >= 1,
|
||||
"num_models must be >=1 when run_single_threaded is false");
|
||||
}
|
||||
model_so_ = std::make_unique<at::DynamicLibrary>(model_so_path.c_str());
|
||||
TORCH_CHECK(model_so_, "Failed to load model: ", model_so_path);
|
||||
|
|
@ -86,11 +84,10 @@ AOTIModelContainerRunner::AOTIModelContainerRunner(
|
|||
? "AOTInductorModelContainerRunSingleThreaded"
|
||||
: "AOTInductorModelContainerRun";
|
||||
TRY_LOAD_SYMBOL(run_func_, run_func_name)
|
||||
if (run_func_ == nullptr && run_single_threaded) {
|
||||
throw std::runtime_error(
|
||||
"No AOTInductorModelContainerRunSingleThreaded function in .so! To use AOTInductor-compiled model in the single-threaded mode,\
|
||||
TORCH_CHECK(
|
||||
run_func_ != nullptr || !run_single_threaded,
|
||||
"No AOTInductorModelContainerRunSingleThreaded function in .so! To use AOTInductor-compiled model in the single-threaded mode,\
|
||||
consider rebuild your model with the latest AOTInductor.");
|
||||
}
|
||||
|
||||
TRY_LOAD_SYMBOL(
|
||||
free_inactive_constant_buffer_func_,
|
||||
|
|
@ -366,10 +363,9 @@ void AOTIModelContainerRunner::swap_constant_buffer() {
|
|||
}
|
||||
|
||||
void AOTIModelContainerRunner::free_inactive_constant_buffer() {
|
||||
if (!free_inactive_constant_buffer_func_) {
|
||||
throw std::runtime_error(
|
||||
"No free_inactive_constant_buffer in .so! Consider rebuild your model with the latest AOTInductor.");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
free_inactive_constant_buffer_func_ != nullptr,
|
||||
"No free_inactive_constant_buffer in .so! Consider rebuild your model with the latest AOTInductor.");
|
||||
AOTI_RUNTIME_ERROR_CODE_CHECK(
|
||||
free_inactive_constant_buffer_func_(container_handle_));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,9 +25,8 @@ std::unique_ptr<AOTIModelContainerRunner> create_aoti_runner_cpu(
|
|||
const std::string& device_str,
|
||||
const std::string& cubin_dir,
|
||||
const bool run_single_threaded) {
|
||||
if (device_str != "cpu") {
|
||||
throw std::runtime_error("Incorrect device passed to aoti_runner_cpu");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
device_str == "cpu", "Incorrect device passed to aoti_runner_cpu");
|
||||
return std::make_unique<AOTIModelContainerRunnerCpu>(
|
||||
model_so_path, num_models, run_single_threaded);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,9 +23,8 @@ std::unique_ptr<AOTIModelContainerRunner> create_aoti_runner_mps(
|
|||
const std::string& device_str,
|
||||
const std::string& cubin_dir,
|
||||
const bool run_single_threaded) {
|
||||
if (device_str != "mps") {
|
||||
throw std::runtime_error("Incorrect device passed to aoti_runner_mps");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
device_str == "mps", "Incorrect device passed to aoti_runner_mps");
|
||||
return std::make_unique<AOTIModelContainerRunnerMps>(
|
||||
model_so_path, num_models, run_single_threaded);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -77,11 +77,11 @@ void convert_handles_to_inputs(
|
|||
|
||||
template <typename T>
|
||||
void assert_numel(const ArrayRefTensor<T>& tensor, uint64_t numel) {
|
||||
if (tensor.numel() != numel) {
|
||||
std::stringstream err;
|
||||
err << "incorrect numel for input tensor. expected " << numel << ", got "
|
||||
<< tensor.numel();
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
TORCH_CHECK(
|
||||
tensor.numel() == numel,
|
||||
"incorrect numel for input tensor. expected ",
|
||||
numel,
|
||||
", got ",
|
||||
tensor.numel());
|
||||
}
|
||||
} // namespace torch::aot_inductor
|
||||
|
|
|
|||
|
|
@ -657,8 +657,8 @@ inline at::vec::Vectorized<float> vec_shuffle_down(
|
|||
case 4:
|
||||
return vec_t(_mm256_permute2f128_ps(x, x, SHUFFLE_MASK(1, 1, 1, 1)));
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"Unhandled vec_shuffle_down value " + std::to_string(n));
|
||||
|
||||
TORCH_CHECK(false, "Unhandled vec_shuffle_down value ", n);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -682,8 +682,8 @@ inline at::vec::Vectorized<float> vec_shuffle_down(
|
|||
return vec_t(_mm512_permutexvar_ps(
|
||||
_mm512_set_epi32(8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8), x));
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"Unhandled vec_shuffle_down value " + std::to_string(n));
|
||||
|
||||
TORCH_CHECK(false, "Unhandled vec_shuffle_down value ", n);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user