mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[AOTInductor] [BE] Add macro for loading symbols in aoti runner (#149249)
Summary: Add macro for loading symbols in aoti runner Test Plan: Existing tests Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/149249 Approved by: https://github.com/chenyang78
This commit is contained in:
parent
24cfeec2c7
commit
1157367c78
|
|
@ -33,39 +33,37 @@ AOTIModelContainerRunner::AOTIModelContainerRunner(
|
|||
const bool run_single_threaded) {
|
||||
model_so_ = std::make_unique<at::DynamicLibrary>(model_so_path.c_str());
|
||||
TORCH_CHECK(model_so_, "Failed to load model: ", model_so_path);
|
||||
create_func_ = reinterpret_cast<decltype(create_func_)>(
|
||||
model_so_->sym("AOTInductorModelContainerCreateWithDevice"));
|
||||
delete_func_ = reinterpret_cast<decltype(delete_func_)>(
|
||||
model_so_->sym("AOTInductorModelContainerDelete"));
|
||||
get_num_outputs_func_ = reinterpret_cast<decltype(get_num_outputs_func_)>(
|
||||
model_so_->sym("AOTInductorModelContainerGetNumOutputs"));
|
||||
run_func_ = reinterpret_cast<decltype(run_func_)>(model_so_->sym(
|
||||
|
||||
#define LOAD_SYMBOL(var, name_str) \
|
||||
var = reinterpret_cast<decltype(var)>(model_so_->sym(name_str));
|
||||
LOAD_SYMBOL(create_func_, "AOTInductorModelContainerCreateWithDevice")
|
||||
LOAD_SYMBOL(delete_func_, "AOTInductorModelContainerDelete")
|
||||
LOAD_SYMBOL(get_num_outputs_func_, "AOTInductorModelContainerGetNumOutputs")
|
||||
LOAD_SYMBOL(
|
||||
run_func_,
|
||||
run_single_threaded ? "AOTInductorModelContainerRunSingleThreaded"
|
||||
: "AOTInductorModelContainerRun"));
|
||||
get_num_constants_func_ = reinterpret_cast<decltype(get_num_constants_func_)>(
|
||||
model_so_->sym("AOTInductorModelContainerGetNumConstants"));
|
||||
get_constant_name_func_ = reinterpret_cast<decltype(get_constant_name_func_)>(
|
||||
model_so_->sym("AOTInductorModelContainerGetConstantName"));
|
||||
get_constant_original_fqn_func_ =
|
||||
reinterpret_cast<decltype(get_constant_original_fqn_func_)>(
|
||||
model_so_->sym("AOTInductorModelContainerGetConstantOriginalFQN"));
|
||||
get_constant_dtype_func_ =
|
||||
reinterpret_cast<decltype(get_constant_dtype_func_)>(
|
||||
model_so_->sym("AOTInductorModelContainerGetConstantDtype"));
|
||||
update_constant_buffer_func_ =
|
||||
reinterpret_cast<decltype(update_constant_buffer_func_)>(
|
||||
model_so_->sym("AOTInductorModelContainerUpdateConstantBuffer"));
|
||||
update_inactive_constant_buffer_func_ =
|
||||
reinterpret_cast<decltype(update_inactive_constant_buffer_func_)>(
|
||||
model_so_->sym(
|
||||
"AOTInductorModelContainerUpdateInactiveConstantBuffer"));
|
||||
run_const_fold_func_ = reinterpret_cast<decltype(run_const_fold_func_)>(
|
||||
model_so_->sym("AOTInductorModelContainerRunConstantFolding"));
|
||||
swap_constant_buffer_func_ =
|
||||
reinterpret_cast<decltype(swap_constant_buffer_func_)>(
|
||||
model_so_->sym("AOTInductorModelContainerSwapConstantBuffer"));
|
||||
get_call_spec_func_ = reinterpret_cast<decltype(get_call_spec_func_)>(
|
||||
model_so_->sym("AOTInductorModelContainerGetCallSpec"));
|
||||
: "AOTInductorModelContainerRun")
|
||||
LOAD_SYMBOL(
|
||||
get_num_constants_func_, "AOTInductorModelContainerGetNumConstants")
|
||||
LOAD_SYMBOL(
|
||||
get_constant_name_func_, "AOTInductorModelContainerGetConstantName")
|
||||
LOAD_SYMBOL(
|
||||
get_constant_original_fqn_func_,
|
||||
"AOTInductorModelContainerGetConstantOriginalFQN")
|
||||
LOAD_SYMBOL(
|
||||
get_constant_dtype_func_, "AOTInductorModelContainerGetConstantDtype")
|
||||
LOAD_SYMBOL(
|
||||
update_constant_buffer_func_,
|
||||
"AOTInductorModelContainerUpdateConstantBuffer")
|
||||
LOAD_SYMBOL(
|
||||
update_inactive_constant_buffer_func_,
|
||||
"AOTInductorModelContainerUpdateInactiveConstantBuffer")
|
||||
LOAD_SYMBOL(
|
||||
run_const_fold_func_, "AOTInductorModelContainerRunConstantFolding")
|
||||
LOAD_SYMBOL(
|
||||
swap_constant_buffer_func_, "AOTInductorModelContainerSwapConstantBuffer")
|
||||
LOAD_SYMBOL(get_call_spec_func_, "AOTInductorModelContainerGetCallSpec")
|
||||
#undef LOAD_SYMBOL
|
||||
|
||||
// Hack to find the json file name from the model so file
|
||||
size_t lastindex = model_so_path.find_last_of('.');
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user