Standalone compile API in _Exporter (#158139)
Given a `package: _ExportPackage`, users can get a ready-to-use workspace in `tmp_dir` by calling:
```python
package._compiled_and_package(
    tmp_dir + "/pt2_package_name.pt2", True, package_example_inputs=True
)
```
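For context, `package` can be assembled from exported models the way the `test_compile_with_exporter` test added in this PR does. Below is a minimal sketch of that flow; the model classes and variable names are illustrative, and `_exporter`/`_define_overload` are the private test-facing APIs used in the diff, not a stable public interface.
```python
import torch
from torch.export.experimental import _ExportPackage


class Plus(torch.nn.Module):
    def forward(self, x, y):
        return x + y


class Minus(torch.nn.Module):
    def forward(self, x, y):
        return x - y


def default(*args, **kwargs):
    # Overload definition mirroring the test; it simply returns None.
    return None


example_inputs = (torch.ones(3, 3), torch.ones(3, 3))

package = _ExportPackage()
# Register one "default" overload per model and trace it with the example inputs.
exporter_plus = package._exporter("Plus", Plus())._define_overload("default", default)
exporter_minus = package._exporter("Minus", Minus())._define_overload("default", default)
exporter_plus(*example_inputs)
exporter_minus(*example_inputs)
```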
`tmp_dir` will contain:
- `main.cpp` (an example cpp file that creates the models; if `package_example_inputs` is True, it also loads the example inputs and runs the models)
- `CMakeLists.txt` (see the build-and-run sketch after this list)
- `pt2_package_name/` (this is where the models are)
- `pt2_package_name.pt2`
- example input `.pt` files (one per input tensor) if `package_example_inputs` is True
Remaining TODOs:
- support loading constants/weights
- the `package_example_inputs = True` option only supports a list of Tensors for now
- eventually we should remove the `torch` dependency and use `SlimTensor`/`StableIValue` instead.
Test Plan:
```
python test/inductor/test_aot_inductor_package.py -k test_compile_with_exporter
```
Example generated `main.cpp`:
```cpp
#include <dlfcn.h>
#include <fstream>
#include <iostream>
#include <memory>
#include <torch/torch.h>
#include <vector>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include "package/data/aotinductor/Plus__default/Plus__default.h"
#include "package/data/aotinductor/Minus__default/Minus__default.h"
using torch::aot_inductor::AOTInductorModelPlus__default;
using torch::aot_inductor::AOTInductorModelMinus__default;
using torch::aot_inductor::ConstantHandle;
using torch::aot_inductor::ConstantMap;
int main(int argc, char* argv[]) {
  std::string device_str = "cpu";
  try {
    c10::Device device(device_str);
    // Load input tensors for model Plus__default
    std::vector<at::Tensor> input_tensors1;
    for (int j = 0; j < 2; ++j) {
      std::string filename = "Plus__default_input_" + std::to_string(j) + ".pt";
      std::ifstream in(filename, std::ios::binary);
      if (!in.is_open()) {
        std::cerr << "Failed to open file: " << filename << std::endl;
        return 1;
      }
      std::vector<char> buffer((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
      torch::IValue ivalue = torch::pickle_load(buffer);
      input_tensors1.push_back(ivalue.toTensor().to(device));
    }
    // Load input tensors for model Minus__default
    std::vector<at::Tensor> input_tensors2;
    for (int j = 0; j < 2; ++j) {
      std::string filename = "Minus__default_input_" + std::to_string(j) + ".pt";
      std::ifstream in(filename, std::ios::binary);
      if (!in.is_open()) {
        std::cerr << "Failed to open file: " << filename << std::endl;
        return 1;
      }
      std::vector<char> buffer((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
      torch::IValue ivalue = torch::pickle_load(buffer);
      input_tensors2.push_back(ivalue.toTensor().to(device));
    }
    // Create array of input handles
    auto input_handles1 =
        torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors1);
    auto input_handles2 =
        torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors2);
    // Create array for output handles
    AtenTensorHandle output_handle1;
    AtenTensorHandle output_handle2;
    // Create and load models
    auto constants_map1 = std::make_shared<ConstantMap>();
    auto constants_array1 = std::make_shared<std::vector<ConstantHandle>>();
    auto model1 = AOTInductorModelPlus__default::Create(
        constants_map1, constants_array1, device_str,
        "package/data/aotinductor/Plus__default/");
    model1->load_constants();
    auto constants_map2 = std::make_shared<ConstantMap>();
    auto constants_array2 = std::make_shared<std::vector<ConstantHandle>>();
    auto model2 = AOTInductorModelMinus__default::Create(
        constants_map2, constants_array2, device_str,
        "package/data/aotinductor/Minus__default/");
    model2->load_constants();
    // Run the models
    torch::aot_inductor::DeviceStreamType stream1 = nullptr;
    model1->run(&input_handles1[0], &output_handle1, stream1, nullptr);
    torch::aot_inductor::DeviceStreamType stream2 = nullptr;
    model2->run(&input_handles2[0], &output_handle2, stream2, nullptr);
    // Convert output handles to tensors
    auto output_tensor1 =
        torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle1, 1);
    auto output_tensor2 =
        torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle2, 1);
    // Validate outputs
    std::cout << "output_tensor1" << output_tensor1 << std::endl;
    std::cout << "output_tensor2" << output_tensor2 << std::endl;
    return 0;
  } catch (const std::exception &e) {
    std::cerr << "Error: " << e.what() << std::endl;
    return 1;
  }
}
```
Rollback Plan:
Differential Revision: D78124705
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158139
Approved by: https://github.com/desertfire
This commit is contained in:
parent 46915b1361
commit cf3247b74a
test/inductor/test_aot_inductor_package.py

@@ -20,6 +20,7 @@ from torch._inductor.package import AOTICompiledModel, load_package, package_aot
from torch._inductor.test_case import TestCase
from torch._inductor.utils import fresh_cache
from torch.export import Dim
from torch.export.experimental import _ExportPackage
from torch.export.pt2_archive._package import load_pt2, load_weights_to_pt2_contents
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_utils import (

@@ -31,20 +32,6 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


try:
    from test_static_linkage_utils import (
        get_static_linkage_main_cpp_file,
        get_static_linkage_makelist_file_cpu,
        get_static_linkage_makelist_file_cuda,
    )
except ImportError:
    from .test_static_linkage_utils import (
        get_static_linkage_main_cpp_file,
        get_static_linkage_makelist_file_cpu,
        get_static_linkage_makelist_file_cuda,
    )


def skipif(predicate: Callable[[str, bool], bool], reason: str):
    def decorator(func):
        @functools.wraps(func)

@@ -153,6 +140,28 @@ class TestAOTInductorPackage(TestCase):
        if shutil.which("make") is None:
            raise unittest.SkipTest("make is not available")

    def cmake_compile_and_run(self, base_dir):
        custom_env = os.environ.copy()
        custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
        build_path = Path(base_dir) / "build"
        build_path.mkdir()
        subprocess.run(
            ["cmake", ".."],
            cwd=build_path,
            env=custom_env,
            check=True,
        )
        subprocess.run(["make"], cwd=build_path, check=True)
        result = subprocess.run(
            ["./build/main"],
            cwd=base_dir,
            check=True,
            capture_output=True,
            text=True,
        )

        return result

    def cmake_compile(self, model, example_inputs, options, tmp_dir):
        """
        Exports model, compiles it using AOTInductor, extracts the

@@ -412,7 +421,7 @@ class TestAOTInductorPackage(TestCase):
    @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
    @skipIfRocm  # doesn't support multi-arch binary
    @skipIfXpu  # doesn't support multi-arch binary
    def test_run_static_linkage_model(self):
    def test_compile_with_exporter(self):
        self.check_package_cpp_only()

        class Model1(torch.nn.Module):

@@ -423,64 +432,45 @@ class TestAOTInductorPackage(TestCase):
            def forward(self, x, y):
                return x - y

        def default(*args, **kwargs):
            return None

        example_inputs = (
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
            torch.ones(3, 3).to(self.device),
            torch.ones(3, 3).to(self.device),
        )

        model1 = Model1().to(self.device)
        model2 = Model2().to(self.device)
        package = _ExportPackage()
        m1 = Model1()
        m2 = Model2()
        exporter1 = package._exporter("Plus", m1)._define_overload("default", default)
        exporter2 = package._exporter("Minus", m2)._define_overload("default", default)
        exporter1(*example_inputs)
        exporter2(*example_inputs)

        models = [model1, model2]

        i = 0
        model_names = ["Plus", "Minus"]
        with (
            tempfile.TemporaryDirectory() as tmp_dir,
        ):
            for i in range(2):
                model = models[i]
                # TODO: should be done through _ExportPackage
                ep = torch.export.export(model, example_inputs)

                package_path = torch._inductor.aoti_compile_and_package(
                    ep,
                    inductor_configs={
                        "aot_inductor.compile_standalone": True,
                        "always_keep_tensor_constants": True,
                        "aot_inductor.model_name_for_generated_files": model_names[i],
                    },
        for package_example_inputs in [True, False]:
            with (
                tempfile.TemporaryDirectory() as tmp_dir,
            ):
                package._compiled_and_package(
                    tmp_dir + "/package.pt2", True, package_example_inputs
                )
                with (
                    zipfile.ZipFile(package_path, "r") as zip_ref,
                ):
                    zip_ref.extractall(tmp_dir)

                file_str = get_static_linkage_main_cpp_file()
                with open(Path(tmp_dir) / "main.cpp", "w") as f:
                    f.write(file_str)

                if self.device == GPU_TYPE:
                    cmake_file_str = get_static_linkage_makelist_file_cuda()
                else:
                    cmake_file_str = get_static_linkage_makelist_file_cpu()
                with open(Path(tmp_dir) / "CMakeLists.txt", "w") as f:
                    f.write(cmake_file_str)

                build_path = Path(tmp_dir) / "build"
                build_path.mkdir()
                custom_env = os.environ.copy()
                custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
                subprocess.run(
                    ["cmake", ".."],
                    cwd=build_path,
                    env=custom_env,
                )

                subprocess.run(["make"], cwd=build_path, check=True)
                subprocess.run(
                    ["./main", f"{tmp_dir}/", self.device], cwd=build_path, check=True
                )
                # Test compiling generated files
                result = self.cmake_compile_and_run(tmp_dir)
                if package_example_inputs:
                    if self.device == GPU_TYPE:
                        self.assertEqual(
                            result.stdout,
                            "output_tensor1 2 2 2\n 2 2 2\n 2 2 2\n[ CUDAFloatType{3,3} ]\noutput_tensor2 0 0 0\n"
                            " 0 0 0\n 0 0 0\n[ CUDAFloatType{3,3} ]\n",
                        )
                    else:
                        self.assertEqual(
                            result.stdout,
                            "output_tensor1 2 2 2\n 2 2 2\n 2 2 2\n[ CPUFloatType{3,3} ]\noutput_tensor2 0 0 0\n"
                            " 0 0 0\n 0 0 0\n[ CPUFloatType{3,3} ]\n",
                        )

    def test_metadata(self):
        class Model(torch.nn.Module):
test/inductor/test_static_linkage_utils.py (deleted)

@@ -1,157 +0,0 @@
# Owner(s): ["module: inductor"]
from torch.testing._internal.common_utils import run_tests


def get_static_linkage_main_cpp_file():
    return """
#include <dlfcn.h>
#include <iostream>
#include <memory>
#include <torch/torch.h>
#include <vector>

#include <cuda.h>
#include <cuda_runtime_api.h>
// Include the AOTInductor headers
#include "Minus.wrapper/data/aotinductor/model/Minus.h"
#include "Plus.wrapper/data/aotinductor/model/Plus.h"
#include <torch/csrc/inductor/aoti_runtime/model_container.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>

using torch::aot_inductor::AOTInductorModelMinus;
using torch::aot_inductor::AOTInductorModelPlus;
using torch::aot_inductor::ConstantHandle;
using torch::aot_inductor::ConstantMap;


int main(int argc, char* argv[]) {
    if (argc < 2) {
        std::cerr
            << "Usage: ./main <path> <device>"
            << std::endl;
        return 1;
    }
    std::string path = argv[1];
    std::string device_str = argv[2];
    try {
        torch::Device device(device_str);

        // Create two input tensors (10x10)
        auto tensor1 = torch::ones({10, 10}, device);
        auto tensor2 = torch::ones({10, 10}, device);
        // Create two input tensors (10x10)
        auto tensor3 = torch::ones({10, 10}, device);
        auto tensor4 = torch::ones({10, 10}, device);

        std::vector<at::Tensor> input_tensors = {tensor1, tensor2};
        std::vector<at::Tensor> input_tensors2 = {tensor3, tensor4};

        // Create array of input handles
        auto input_handles1 =
            torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(
                input_tensors);
        auto input_handles2 =
            torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(
                input_tensors2);

        // Create array for output handle
        AtenTensorHandle output_handle1;
        AtenTensorHandle output_handle2;

        auto constants_map = std::make_shared<ConstantMap>();
        auto constants_array = std::make_shared<std::vector<ConstantHandle>>();
        auto model1 = AOTInductorModelPlus::Create(
            constants_map, constants_array, device_str,
            path + "Plus.wrapper/data/"
            "aotinductor/model/");
        model1->load_constants();

        auto constants_map2 = std::make_shared<ConstantMap>();
        auto constants_array2 = std::make_shared<std::vector<ConstantHandle>>();
        auto model2 = AOTInductorModelMinus::Create(
            constants_map2, constants_array2, device_str,
            path + "Minus.wrapper/data/"
            "aotinductor/model/");
        model2->load_constants();

        // Run the model
        torch::aot_inductor::DeviceStreamType stream1 = nullptr;
        torch::aot_inductor::DeviceStreamType stream2 = nullptr;
        model1->run(&input_handles1[0], &output_handle1, stream1, nullptr);
        model2->run(&input_handles2[0], &output_handle2, stream2, nullptr);

        // Convert output handle to tensor
        auto output_tensor1 =
            torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
                &output_handle1, 1);
        auto output_tensor2 =
            torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
                &output_handle2, 1);

        if (!(torch::all(output_tensor1[0] == 2).item<bool>())){
            std::cout << "Wrong Output for Plus Model: " << output_tensor1 << std::endl;
            throw std::runtime_error("Tensor does not contain only the expected value 2.");
        }
        if (!(torch::all(output_tensor2[0] == 0).item<bool>())){
            std::cout << "Wrong Output for Minus Model: " << output_tensor1 << std::endl;
            throw std::runtime_error("Tensor does not contain only the expected value 0.");
        }

        return 0;
    } catch (const std::exception &e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
}

"""


def get_static_linkage_makelist_file_cuda():
    return """
cmake_minimum_required(VERSION 3.10)
project(TestProject)

set(CMAKE_CXX_STANDARD 17)

find_package(Torch REQUIRED)
find_package(CUDA REQUIRED)

add_subdirectory(Plus.wrapper/data/aotinductor/model/)
add_subdirectory(Minus.wrapper/data/aotinductor/model/)

# Create executable
add_executable(main main.cpp)

target_compile_definitions(main PRIVATE USE_CUDA)

target_link_libraries(main PRIVATE torch cuda
    ${CUDA_LIBRARIES}
    Plus
    Minus)
"""


def get_static_linkage_makelist_file_cpu():
    return """
cmake_minimum_required(VERSION 3.10)
project(TestProject)

set(CMAKE_CXX_STANDARD 17)

find_package(Torch REQUIRED)

add_subdirectory(Plus.wrapper/data/aotinductor/model/)
add_subdirectory(Minus.wrapper/data/aotinductor/model/)

# Create executable
add_executable(main main.cpp)

target_link_libraries(main PRIVATE torch
    Plus
    Minus)
"""


if __name__ == "__main__":
    run_tests()
torch/export/experimental/__init__.py

@@ -1,14 +1,22 @@
import copy
import dataclasses
import functools
import os
import tempfile
import types
import typing
import typing_extensions
import zipfile
from pathlib import Path

import torch
from torch.export.experimental._utils import _get_main_cpp_file, _get_make_file
from torch.export.exported_program import _decompose_exported_program


__all__ = []  # type: ignore[var-annotated]


def _copy_graph_module_and_signature(
    ep: torch.fx.GraphModule,
) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]:

@@ -333,18 +341,79 @@ class _ExportPackage:
            for overload, ep in method_data.overloads.items():
                yield f"{method}:{overload}", ep

    def _compiled_and_package(self, f: torch.types.FileLike) -> None:
        options = {
    def _compiled_and_package(
        self,
        f: torch.types.FileLike,
        standalone: bool = False,
        package_example_inputs: bool = False,
    ) -> None:
        options: dict[str, typing.Any] = {
            "aot_inductor.package": True,
            "aot_inductor.package_cpp_only": True,
            "always_keep_tensor_constants": True,
            "aot_inductor.package_constants_in_so": False,
            "aot_inductor.compile_standalone": standalone,
        }
        weights_map = {}
        aoti_files_map = {}
        model_names = []
        for name, ep in self._method_overloads:
            weights = torch._inductor.aot_compile(ep.module(), (), options=options)  # type: ignore[arg-type]
            weights_map[name] = weights
            torch._inductor.package.package.package_aoti(
            name = name.replace(":", "__")
            model_names.append(name)
            options["aot_inductor.model_name_for_generated_files"] = name
            aoti_files = torch._inductor.aot_compile(
                ep.module(),  # type: ignore[arg-type]
                ep.example_inputs[0],
                kwargs=ep.example_inputs[1],
                options=options,
            )
            aoti_files_map[name] = aoti_files

        from torch._inductor.package import package

        pt2_path = package.package_aoti(
            f,
            weights_map,  # type: ignore[arg-type]
            aoti_files_map,  # type: ignore[arg-type]
        )

        if not standalone:
            return

        assert isinstance(pt2_path, str)
        base_directory = os.path.dirname(pt2_path)
        package_name = os.path.basename(pt2_path)[:-4]
        with (
            zipfile.ZipFile(pt2_path, "r") as zip_ref,
        ):
            zip_ref.extractall(base_directory)

        example_inputs_map: typing.Optional[dict[str, int]] = (
            {} if package_example_inputs else None
        )
        use_cuda = False
        for name, ep in self._method_overloads:
            name = name.replace(":", "__")
            # TODO: also dump kwargs
            # TODO: currently only support list of Tensors and they need to be on the same device
            if not ep.example_inputs:
                continue
            for inp in ep.example_inputs[0]:
                if isinstance(inp, torch.Tensor) and inp.device.type == "cuda":
                    # TODO: more carefully determine the device type
                    use_cuda = True
            if package_example_inputs:
                assert example_inputs_map is not None
                example_inputs_map[name] = len(ep.example_inputs[0])
                for i, t in enumerate(ep.example_inputs[0]):
                    path = Path(base_directory) / f"{name}_input_{i}.pt"
                    torch.save(t, path)

        cmake_file_str = _get_make_file(package_name, model_names, use_cuda)

        with open(Path(base_directory) / "CMakeLists.txt", "w") as file:
            file.write(cmake_file_str)

        main_file_str = _get_main_cpp_file(
            package_name, model_names, use_cuda, example_inputs_map
        )
        with open(Path(base_directory) / "main.cpp", "w") as file:
            file.write(main_file_str)
torch/export/experimental/_utils.py (new file, 206 lines)

@@ -0,0 +1,206 @@
import typing

from torch._inductor.utils import IndentedBuffer


__all__ = []  # type: ignore[var-annotated]


def _get_main_cpp_file(
    package_name: str,
    model_names: list[str],
    cuda: bool,
    example_inputs_map: typing.Optional[dict[str, int]],
) -> str:
    """
    Generates a main.cpp file for AOTInductor standalone models in the specified package.

    Args:
        package_name (str): Name of the package containing the models.
        model_names (List[str]): List of model names to include in the generated main.cpp.
        cuda (bool): Whether to generate code with CUDA support.
        example_inputs_map (Optional[Dict[str, List[Tensor]]]): A mapping from model name to
            its list of example input tensors. If provided, the generated main.cpp will
            load and run these inputs.

    Returns:
        str: The contents of the generated main.cpp file as a string.
    """

    ib = IndentedBuffer()

    ib.writelines(
        [
            "#include <dlfcn.h>",
            "#include <fstream>",
            "#include <iostream>",
            "#include <memory>",
            "#include <torch/torch.h>",
            "#include <vector>",
            "#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>",
        ]
    )
    if cuda:
        ib.writelines(
            [
                "#include <cuda.h>",
                "#include <cuda_runtime_api.h>",
            ]
        )

    for model_name in model_names:
        ib.writeline(
            f'#include "{package_name}/data/aotinductor/{model_name}/{model_name}.h"'
        )

    ib.newline()
    for model_name in model_names:
        ib.writeline(f"using torch::aot_inductor::AOTInductorModel{model_name};")

    ib.writelines(
        [
            "using torch::aot_inductor::ConstantHandle;",
            "using torch::aot_inductor::ConstantMap;",
            "",
            "int main(int argc, char* argv[]) {",
        ]
    )

    with ib.indent():
        ib.writeline(f'std::string device_str = "{"cuda" if cuda else "cpu"}";')
        ib.writeline("try {")

        with ib.indent():
            ib.writeline("c10::Device device(device_str);")

            if example_inputs_map is not None:
                # TODO: add device
                for i, model_name in enumerate(model_names):
                    num_inputs = example_inputs_map[model_name]

                    ib.writeline(f"// Load input tensors for model {model_name}")
                    ib.writeline(f"std::vector<at::Tensor> input_tensors{i + 1};")
                    ib.writeline(f"for (int j = 0; j < {num_inputs}; ++j) {{")
                    with ib.indent():
                        ib.writeline(
                            f'std::string filename = "{model_name}_input_" + std::to_string(j) + ".pt";'
                        )
                        ib.writeline("std::ifstream in(filename, std::ios::binary);")
                        ib.writeline("if (!in.is_open()) {")
                        with ib.indent():
                            ib.writeline(
                                'std::cerr << "Failed to open file: " << filename << std::endl;'
                            )
                            ib.writeline("return 1;")
                        ib.writeline("}")
                        ib.writeline(
                            "std::vector<char> buffer((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());"
                        )
                        ib.writeline(
                            "torch::IValue ivalue = torch::pickle_load(buffer);"
                        )
                        ib.writeline(
                            f"input_tensors{i + 1}.push_back(ivalue.toTensor().to(device));"
                        )
                    ib.writeline("}")
                    ib.newline()

            ib.newline()
            ib.writeline("\n// Create array of input handles")
            for i in range(len(model_names)):
                ib.writelines(
                    [
                        f"auto input_handles{i + 1} =",
                        f" torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors{i + 1});",
                    ]
                )

            ib.writeline("\n// Create array for output handles")
            for i in range(len(model_names)):
                ib.writeline(f"AtenTensorHandle output_handle{i + 1};")

            ib.writeline("\n// Create and load models")
            for i, model_name in enumerate(model_names):
                ib.writelines(
                    [
                        f"auto constants_map{i + 1} = std::make_shared<ConstantMap>();",
                        f"auto constants_array{i + 1} = std::make_shared<std::vector<ConstantHandle>>();",
                        f"auto model{i + 1} = AOTInductorModel{model_name}::Create(",
                        f" constants_map{i + 1}, constants_array{i + 1}, device_str,",
                        f' "{package_name}/data/aotinductor/{model_name}/");',
                        f"model{i + 1}->load_constants();",
                    ]
                )

            if example_inputs_map is not None:
                ib.writeline("\n// Run the models")
                for i in range(len(model_names)):
                    ib.writeline(
                        f"torch::aot_inductor::DeviceStreamType stream{i + 1} = nullptr;"
                    )
                    ib.writeline(
                        f"model{i + 1}->run(&input_handles{i + 1}[0], &output_handle{i + 1}, stream{i + 1}, nullptr);"
                    )

                ib.writeline("\n// Convert output handles to tensors")
                for i in range(len(model_names)):
                    ib.writelines(
                        [
                            f"auto output_tensor{i + 1} =",
                            f" torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle{i + 1}, 1);",
                        ]
                    )

                ib.writeline("\n// Validate outputs")
                for i in range(len(model_names)):
                    ib.writeline(
                        f"""std::cout << "output_tensor{i + 1}" << output_tensor{i + 1} << std::endl;"""
                    )

            ib.writeline("return 0;")

        ib.writelines(
            [
                "} catch (const std::exception &e) {",
            ]
        )
        with ib.indent():
            ib.writeline('std::cerr << "Error: " << e.what() << std::endl;')
            ib.writeline("return 1;")

        ib.writeline("}")
    ib.writeline("}")

    return ib.getvalue()


def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str:
    ib = IndentedBuffer()

    ib.writelines(
        [
            "cmake_minimum_required(VERSION 3.10)",
            "project(TestProject)",
            "",
            "set(CMAKE_CXX_STANDARD 17)",
            "",
            "find_package(Torch REQUIRED)",
        ]
    )
    if cuda:
        ib.writeline("find_package(CUDA REQUIRED)")

    ib.newline()
    for model_name in model_names:
        ib.writeline(f"add_subdirectory({package_name}/data/aotinductor/{model_name}/)")

    ib.writeline("\nadd_executable(main main.cpp)")
    if cuda:
        ib.writeline("target_compile_definitions(main PRIVATE USE_CUDA)")

    model_libs = " ".join(model_names)
    ib.writeline(f"target_link_libraries(main PRIVATE torch {model_libs})")
    if cuda:
        ib.writeline("target_link_libraries(main PRIVATE cuda ${CUDA_LIBRARIES})")

    return ib.getvalue()