pytorch/test/cpp/aoti_inference/aoti_custom_class.h
Bin Bao 4946638f06 [AOTI] Add ABI-compatibility tests (#123848)
Summary: AOTInductor-generated CPU model code can reference some aten/c10 utility functions and data structures directly, e.g. at::vec and c10::Half. These are performance-critical, so it doesn't make sense to create a C shim for them. Instead, we make sure they are implemented header-only, and use this set of tests to guard future changes.

More header files still need to be updated; we will do that in follow-up PRs.
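The sketch below shows one way a header-only type can be guarded against layout changes. It uses a hypothetical `demo::Half`, which is a stand-in and not c10::Half. Every member of `demo::Half` is defined inline, so code that uses it needs no symbols from a shared library. `static_assert` checks pin its size, alignment, and layout traits at compile time. This is an assumed style of guard for illustration only; it is not the set of tests added in this PR.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <type_traits>

namespace demo {

// Hypothetical header-only value type (NOT c10::Half). All members are defined
// inline, so code that uses it needs no symbols from a shared library.
struct Half {
  std::uint16_t x;  // raw IEEE 754 binary16 bit pattern

  // Construct directly from a bit pattern.
  static constexpr Half from_bits(std::uint16_t bits) { return Half{bits}; }
};

// Read back the raw bits; inline so it can live in a header.
inline std::uint16_t to_bits(Half h) {
  std::uint16_t bits = 0;
  std::memcpy(&bits, &h, sizeof(bits));
  return bits;
}

// Compile-time guards on layout properties that separately compiled code may
// depend on. A change that breaks any of them fails the build.
static_assert(sizeof(Half) == 2, "Half must stay 2 bytes");
static_assert(alignof(Half) == 2, "Half must stay 2-byte aligned");
static_assert(std::is_standard_layout<Half>::value,
              "Half must stay standard-layout");
static_assert(std::is_trivially_copyable<Half>::value,
              "Half must stay trivially copyable");

}  // namespace demo
```

For example, `demo::to_bits(demo::Half::from_bits(0x3C00))` yields `0x3C00`, which is the binary16 encoding of 1.0.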

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123848
Approved by: https://github.com/jansel
ghstack dependencies: #123847
2024-04-19 00:51:24 +00:00

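#pragma once

#include <memory>
#include <string>
#include <vector>

#include <torch/torch.h>

// Forward-declared so this header does not need to include the runner's header.
namespace torch::inductor {
class AOTIModelContainerRunner;
} // namespace torch::inductor

namespace torch::aot_inductor {

// Wraps an AOTInductor-compiled model library. Deriving from
// torch::CustomClassHolder lets the class be registered as a TorchScript
// custom class.
class MyAOTIClass : public torch::CustomClassHolder {
 public:
  explicit MyAOTIClass(
      const std::string& model_path,
      const std::string& device = "cuda");

  ~MyAOTIClass() {}

  // Neither copyable nor movable: declaring the copy constructor also
  // suppresses the implicit move constructor.
  MyAOTIClass(const MyAOTIClass&) = delete;
  MyAOTIClass& operator=(const MyAOTIClass&) = delete;
  MyAOTIClass& operator=(MyAOTIClass&&) = delete;

  const std::string& lib_path() const {
    return lib_path_;
  }

  const std::string& device() const {
    return device_;
  }

  // Runs the compiled model on `inputs` and returns its outputs.
  std::vector<torch::Tensor> forward(std::vector<torch::Tensor> inputs);

 private:
  const std::string lib_path_;
  const std::string device_;
  std::unique_ptr<torch::inductor::AOTIModelContainerRunner> runner_;
};

} // namespace torch::aot_inductor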
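The header forward-declares `AOTIModelContainerRunner` and holds it through `std::unique_ptr`, which keeps the runner's definition out of the header. The self-contained sketch below shows that pattern with hypothetical `demo::Holder` and `demo::Runner` classes, which stand in for `MyAOTIClass` and the AOTInductor runner. Its "model" just doubles each input and is purely illustrative.

`std::default_delete` requires the pointee to be a complete type where the deletion is instantiated. For that reason, the sketch defines `~Holder()` out of line, after `Runner` is complete; the header above instead defines its destructor inline. The trailing `static_assert`s confirm the copy/move behavior produced by the same set of deleted members.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

namespace demo {

// Forward declaration only (hypothetical stand-in for the real runner).
class Runner;

class Holder {
 public:
  explicit Holder(const std::string& lib_path);
  ~Holder();  // defined below, after Runner is a complete type

  Holder(const Holder&) = delete;
  Holder& operator=(const Holder&) = delete;
  Holder& operator=(Holder&&) = delete;

  const std::string& lib_path() const { return lib_path_; }
  std::vector<int> forward(std::vector<int> inputs);

 private:
  const std::string lib_path_;
  std::unique_ptr<Runner> runner_;
};

// In a real project this definition would live in the .cpp file.
class Runner {
 public:
  std::vector<int> run(std::vector<int> inputs) {
    for (int& v : inputs) v *= 2;  // placeholder "model": double each element
    return inputs;
  }
};

Holder::Holder(const std::string& lib_path)
    : lib_path_(lib_path), runner_(std::make_unique<Runner>()) {}

// Runner is complete here, so unique_ptr<Runner> can delete it.
Holder::~Holder() = default;

std::vector<int> Holder::forward(std::vector<int> inputs) {
  return runner_->run(std::move(inputs));
}

// Deleting the copy constructor also suppresses the implicit move
// constructor, so Holder is neither copyable nor movable.
static_assert(!std::is_copy_constructible<Holder>::value, "no copy ctor");
static_assert(!std::is_move_constructible<Holder>::value, "no move ctor");
static_assert(!std::is_copy_assignable<Holder>::value, "no copy assign");
static_assert(!std::is_move_assignable<Holder>::value, "no move assign");

}  // namespace demo
```

Usage: `demo::Holder h("model.so"); h.forward({1, 2, 3});` returns `{2, 4, 6}`.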