pytorch/test/cpp/aoti_inference/aoti_custom_class.cpp
Bin Bao 1cb2ebd740 [AOTI] Fix #140546 and support AOTI package load for Intel GPU. (#140664)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #140686
* __->__ #140664
* #140269
* #140268
* #135320
* #135318
* #139026

Fix #140546

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140664
Approved by: https://github.com/desertfire, https://github.com/EikanWang
ghstack dependencies: #140268, #140269

Co-authored-by: Bin Bao <binbao@meta.com>
2024-12-10 05:05:08 +00:00

56 lines
1.7 KiB
C++

#include <stdexcept>

#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#if defined(USE_CUDA) || defined(USE_ROCM)
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#if defined(USE_XPU)
#include <torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h>
#endif

#include "aoti_custom_class.h"
namespace torch::aot_inductor {
// Registers MyAOTIClass as a TorchScript custom class under the
// "aoti" namespace, exposing its two-string constructor, forward(),
// and pickle support so instances can be (de)serialized.
static auto registerMyAOTIClass =
    torch::class_<MyAOTIClass>("aoti", "MyAOTIClass")
        .def(torch::init<std::string, std::string>())
        .def("forward", &MyAOTIClass::forward)
        .def_pickle(
            // Serialize: the object's full state is its library path and
            // device string, captured as a two-element string vector.
            [](const c10::intrusive_ptr<MyAOTIClass>& self)
                -> std::vector<std::string> {
              return {self->lib_path(), self->device()};
            },
            // Deserialize: rebuild the object from the saved state.
            [](std::vector<std::string> params) {
              return c10::make_intrusive<MyAOTIClass>(params[0], params[1]);
            });
// Constructs a MyAOTIClass by loading an AOTInductor-compiled model
// library and instantiating the runner matching the requested device.
//
// @param model_path path to the AOTI-compiled shared library
// @param device     target device string: "cpu", plus "cuda" / "xpu"
//                   when the corresponding backend is compiled in
// @throws std::runtime_error if `device` names an unsupported or
//         compiled-out backend
MyAOTIClass::MyAOTIClass(
    const std::string& model_path,
    const std::string& device)
    : lib_path_(model_path), device_(device) {
  if (device_ == "cpu") {
    // Pass the std::string directly; the runner constructors take a
    // const std::string&, so going through c_str() would only rebuild
    // a temporary string from the same buffer.
    runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCpu>(
        model_path);
#if defined(USE_CUDA) || defined(USE_ROCM)
  } else if (device_ == "cuda") {
    runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
        model_path);
#endif
#if defined(USE_XPU)
  } else if (device_ == "xpu") {
    runner_ = std::make_unique<torch::inductor::AOTIModelContainerRunnerXpu>(
        model_path);
#endif
  } else {
    throw std::runtime_error("invalid device: " + device);
  }
}
// Runs inference through the loaded AOTI model container.
// @param inputs input tensors in the order the compiled model expects
// @return the output tensors produced by the model run
std::vector<torch::Tensor> MyAOTIClass::forward(
    std::vector<torch::Tensor> inputs) {
  auto outputs = runner_->run(inputs);
  return outputs;
}
} // namespace torch::aot_inductor