[AOTInductor] [BE] Add macro for loading symbols in aoti runner (#149249)

Summary:
Add macro for loading symbols in aoti runner

Test Plan:
Existing tests

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149249
Approved by: https://github.com/chenyang78
This commit is contained in:
Mu-Chu Lee 2025-03-15 00:21:21 -07:00 committed by PyTorch MergeBot
parent 24cfeec2c7
commit 1157367c78

View File

@ -33,39 +33,37 @@ AOTIModelContainerRunner::AOTIModelContainerRunner(
const bool run_single_threaded) {
model_so_ = std::make_unique<at::DynamicLibrary>(model_so_path.c_str());
TORCH_CHECK(model_so_, "Failed to load model: ", model_so_path);
create_func_ = reinterpret_cast<decltype(create_func_)>(
model_so_->sym("AOTInductorModelContainerCreateWithDevice"));
delete_func_ = reinterpret_cast<decltype(delete_func_)>(
model_so_->sym("AOTInductorModelContainerDelete"));
get_num_outputs_func_ = reinterpret_cast<decltype(get_num_outputs_func_)>(
model_so_->sym("AOTInductorModelContainerGetNumOutputs"));
run_func_ = reinterpret_cast<decltype(run_func_)>(model_so_->sym(
#define LOAD_SYMBOL(var, name_str) \
var = reinterpret_cast<decltype(var)>(model_so_->sym(name_str));
LOAD_SYMBOL(create_func_, "AOTInductorModelContainerCreateWithDevice")
LOAD_SYMBOL(delete_func_, "AOTInductorModelContainerDelete")
LOAD_SYMBOL(get_num_outputs_func_, "AOTInductorModelContainerGetNumOutputs")
LOAD_SYMBOL(
run_func_,
run_single_threaded ? "AOTInductorModelContainerRunSingleThreaded"
: "AOTInductorModelContainerRun"));
get_num_constants_func_ = reinterpret_cast<decltype(get_num_constants_func_)>(
model_so_->sym("AOTInductorModelContainerGetNumConstants"));
get_constant_name_func_ = reinterpret_cast<decltype(get_constant_name_func_)>(
model_so_->sym("AOTInductorModelContainerGetConstantName"));
get_constant_original_fqn_func_ =
reinterpret_cast<decltype(get_constant_original_fqn_func_)>(
model_so_->sym("AOTInductorModelContainerGetConstantOriginalFQN"));
get_constant_dtype_func_ =
reinterpret_cast<decltype(get_constant_dtype_func_)>(
model_so_->sym("AOTInductorModelContainerGetConstantDtype"));
update_constant_buffer_func_ =
reinterpret_cast<decltype(update_constant_buffer_func_)>(
model_so_->sym("AOTInductorModelContainerUpdateConstantBuffer"));
update_inactive_constant_buffer_func_ =
reinterpret_cast<decltype(update_inactive_constant_buffer_func_)>(
model_so_->sym(
"AOTInductorModelContainerUpdateInactiveConstantBuffer"));
run_const_fold_func_ = reinterpret_cast<decltype(run_const_fold_func_)>(
model_so_->sym("AOTInductorModelContainerRunConstantFolding"));
swap_constant_buffer_func_ =
reinterpret_cast<decltype(swap_constant_buffer_func_)>(
model_so_->sym("AOTInductorModelContainerSwapConstantBuffer"));
get_call_spec_func_ = reinterpret_cast<decltype(get_call_spec_func_)>(
model_so_->sym("AOTInductorModelContainerGetCallSpec"));
: "AOTInductorModelContainerRun")
LOAD_SYMBOL(
get_num_constants_func_, "AOTInductorModelContainerGetNumConstants")
LOAD_SYMBOL(
get_constant_name_func_, "AOTInductorModelContainerGetConstantName")
LOAD_SYMBOL(
get_constant_original_fqn_func_,
"AOTInductorModelContainerGetConstantOriginalFQN")
LOAD_SYMBOL(
get_constant_dtype_func_, "AOTInductorModelContainerGetConstantDtype")
LOAD_SYMBOL(
update_constant_buffer_func_,
"AOTInductorModelContainerUpdateConstantBuffer")
LOAD_SYMBOL(
update_inactive_constant_buffer_func_,
"AOTInductorModelContainerUpdateInactiveConstantBuffer")
LOAD_SYMBOL(
run_const_fold_func_, "AOTInductorModelContainerRunConstantFolding")
LOAD_SYMBOL(
swap_constant_buffer_func_, "AOTInductorModelContainerSwapConstantBuffer")
LOAD_SYMBOL(get_call_spec_func_, "AOTInductorModelContainerGetCallSpec")
#undef LOAD_SYMBOL
// Hack to find the json file name from the model so file
size_t lastindex = model_so_path.find_last_of('.');