Standalone compile API in _Exporter (#158139)

Given an `package: _ExportPackage`, users can get a ready-to-use workspace in `tmp_dir` by calling:
```python
package._compiled_and_package(
                tmp_dir + "/pt2_pacakge_name.pt2", True, package_example_inputs = True
            )
```

`tmp_dir` will contains:
- `main.cpp` (an example cpp file that create the models, if package_example_inputs is True, it'll also load the example inputs and run the models)
- `CMakeLists.txt`
- `pt2_pacakge_name/` (this is where the models are)
- `pt2_pacakge_name.pt2`
- `inputs.pt` files if package_example_inputs is True

Remaining TODOs
- support loading contants/weights
- the `package_example_inputs = True` option only supports a list of Tensors for now
- eventually we should remove the `torch` dependency, and use `SlimTensor`/`StableIValue` instead.

Test Plan:
```
python test/inductor/test_aot_inductor_package.py  -k test_compile_with_exporter
```

Example generated `main.cpp`:

```cpp
#include <dlfcn.h>
#include <fstream>
#include <iostream>
#include <memory>
#include <torch/torch.h>
#include <vector>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include "package/data/aotinductor/Plus__default/Plus__default.h"
#include "package/data/aotinductor/Minus__default/Minus__default.h"

using torch::aot_inductor::AOTInductorModelPlus__default;
using torch::aot_inductor::AOTInductorModelMinus__default;
using torch::aot_inductor::ConstantHandle;
using torch::aot_inductor::ConstantMap;

int main(int argc, char* argv[]) {
    std::string device_str = "cpu";
    try {
        c10::Device device(device_str);
        // Load input tensors for model Plus__default
        std::vector<at::Tensor> input_tensors1;
        for (int j = 0; j < 2; ++j) {
            std::string filename = "Plus__default_input_" + std::to_string(j) + ".pt";
            std::ifstream in(filename, std::ios::binary);
            if (!in.is_open()) {
                std::cerr << "Failed to open file: " << filename << std::endl;
                return 1;
            }
            std::vector<char> buffer((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
            torch::IValue ivalue = torch::pickle_load(buffer);
            input_tensors1.push_back(ivalue.toTensor().to(device));
        }

        // Load input tensors for model Minus__default
        std::vector<at::Tensor> input_tensors2;
        for (int j = 0; j < 2; ++j) {
            std::string filename = "Minus__default_input_" + std::to_string(j) + ".pt";
            std::ifstream in(filename, std::ios::binary);
            if (!in.is_open()) {
                std::cerr << "Failed to open file: " << filename << std::endl;
                return 1;
            }
            std::vector<char> buffer((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
            torch::IValue ivalue = torch::pickle_load(buffer);
            input_tensors2.push_back(ivalue.toTensor().to(device));
        }

// Create array of input handles
        auto input_handles1 =
            torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors1);
        auto input_handles2 =
            torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors2);

// Create array for output handles
        AtenTensorHandle output_handle1;
        AtenTensorHandle output_handle2;

// Create and load models
        auto constants_map1 = std::make_shared<ConstantMap>();
        auto constants_array1 = std::make_shared<std::vector<ConstantHandle>>();
        auto model1 = AOTInductorModelPlus__default::Create(
            constants_map1, constants_array1, device_str,
            "package/data/aotinductor/Plus__default/");
        model1->load_constants();
        auto constants_map2 = std::make_shared<ConstantMap>();
        auto constants_array2 = std::make_shared<std::vector<ConstantHandle>>();
        auto model2 = AOTInductorModelMinus__default::Create(
            constants_map2, constants_array2, device_str,
            "package/data/aotinductor/Minus__default/");
        model2->load_constants();

// Run the models
        torch::aot_inductor::DeviceStreamType stream1 = nullptr;
        model1->run(&input_handles1[0], &output_handle1, stream1, nullptr);
        torch::aot_inductor::DeviceStreamType stream2 = nullptr;
        model2->run(&input_handles2[0], &output_handle2, stream2, nullptr);

// Convert output handles to tensors
        auto output_tensor1 =
            torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle1, 1);
        auto output_tensor2 =
            torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle2, 1);

// Validate outputs
        std::cout << "output_tensor1" << output_tensor1 << std::endl;
        std::cout << "output_tensor2" << output_tensor2 << std::endl;
        return 0;
    } catch (const std::exception &e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
}

```

Rollback Plan:

Differential Revision: D78124705

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158139
Approved by: https://github.com/desertfire
This commit is contained in:
Shangdi Yu 2025-07-15 18:47:52 +00:00 committed by PyTorch MergeBot
parent 46915b1361
commit cf3247b74a
4 changed files with 339 additions and 231 deletions

View File

@ -20,6 +20,7 @@ from torch._inductor.package import AOTICompiledModel, load_package, package_aot
from torch._inductor.test_case import TestCase
from torch._inductor.utils import fresh_cache
from torch.export import Dim
from torch.export.experimental import _ExportPackage
from torch.export.pt2_archive._package import load_pt2, load_weights_to_pt2_contents
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_utils import (
@ -31,20 +32,6 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
try:
from test_static_linkage_utils import (
get_static_linkage_main_cpp_file,
get_static_linkage_makelist_file_cpu,
get_static_linkage_makelist_file_cuda,
)
except ImportError:
from .test_static_linkage_utils import (
get_static_linkage_main_cpp_file,
get_static_linkage_makelist_file_cpu,
get_static_linkage_makelist_file_cuda,
)
def skipif(predicate: Callable[[str, bool], bool], reason: str):
def decorator(func):
@functools.wraps(func)
@ -153,6 +140,28 @@ class TestAOTInductorPackage(TestCase):
if shutil.which("make") is None:
raise unittest.SkipTest("make is not available")
def cmake_compile_and_run(self, base_dir):
custom_env = os.environ.copy()
custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
build_path = Path(base_dir) / "build"
build_path.mkdir()
subprocess.run(
["cmake", ".."],
cwd=build_path,
env=custom_env,
check=True,
)
subprocess.run(["make"], cwd=build_path, check=True)
result = subprocess.run(
["./build/main"],
cwd=base_dir,
check=True,
capture_output=True,
text=True,
)
return result
def cmake_compile(self, model, example_inputs, options, tmp_dir):
"""
Exports model, compiles it using AOTInductor, extracts the
@ -412,7 +421,7 @@ class TestAOTInductorPackage(TestCase):
@unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
@skipIfRocm # doesn't support multi-arch binary
@skipIfXpu # doesn't support multi-arch binary
def test_run_static_linkage_model(self):
def test_compile_with_exporter(self):
self.check_package_cpp_only()
class Model1(torch.nn.Module):
@ -423,64 +432,45 @@ class TestAOTInductorPackage(TestCase):
def forward(self, x, y):
return x - y
def default(*args, **kwargs):
return None
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
torch.ones(3, 3).to(self.device),
torch.ones(3, 3).to(self.device),
)
model1 = Model1().to(self.device)
model2 = Model2().to(self.device)
package = _ExportPackage()
m1 = Model1()
m2 = Model2()
exporter1 = package._exporter("Plus", m1)._define_overload("default", default)
exporter2 = package._exporter("Minus", m2)._define_overload("default", default)
exporter1(*example_inputs)
exporter2(*example_inputs)
models = [model1, model2]
i = 0
model_names = ["Plus", "Minus"]
with (
tempfile.TemporaryDirectory() as tmp_dir,
):
for i in range(2):
model = models[i]
# TODO: should be done through _ExportPackage
ep = torch.export.export(model, example_inputs)
package_path = torch._inductor.aoti_compile_and_package(
ep,
inductor_configs={
"aot_inductor.compile_standalone": True,
"always_keep_tensor_constants": True,
"aot_inductor.model_name_for_generated_files": model_names[i],
},
for package_example_inputs in [True, False]:
with (
tempfile.TemporaryDirectory() as tmp_dir,
):
package._compiled_and_package(
tmp_dir + "/package.pt2", True, package_example_inputs
)
with (
zipfile.ZipFile(package_path, "r") as zip_ref,
):
zip_ref.extractall(tmp_dir)
file_str = get_static_linkage_main_cpp_file()
with open(Path(tmp_dir) / "main.cpp", "w") as f:
f.write(file_str)
if self.device == GPU_TYPE:
cmake_file_str = get_static_linkage_makelist_file_cuda()
else:
cmake_file_str = get_static_linkage_makelist_file_cpu()
with open(Path(tmp_dir) / "CMakeLists.txt", "w") as f:
f.write(cmake_file_str)
build_path = Path(tmp_dir) / "build"
build_path.mkdir()
custom_env = os.environ.copy()
custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
subprocess.run(
["cmake", ".."],
cwd=build_path,
env=custom_env,
)
subprocess.run(["make"], cwd=build_path, check=True)
subprocess.run(
["./main", f"{tmp_dir}/", self.device], cwd=build_path, check=True
)
# Test compiling generated files
result = self.cmake_compile_and_run(tmp_dir)
if package_example_inputs:
if self.device == GPU_TYPE:
self.assertEqual(
result.stdout,
"output_tensor1 2 2 2\n 2 2 2\n 2 2 2\n[ CUDAFloatType{3,3} ]\noutput_tensor2 0 0 0\n"
" 0 0 0\n 0 0 0\n[ CUDAFloatType{3,3} ]\n",
)
else:
self.assertEqual(
result.stdout,
"output_tensor1 2 2 2\n 2 2 2\n 2 2 2\n[ CPUFloatType{3,3} ]\noutput_tensor2 0 0 0\n"
" 0 0 0\n 0 0 0\n[ CPUFloatType{3,3} ]\n",
)
def test_metadata(self):
class Model(torch.nn.Module):

View File

@ -1,157 +0,0 @@
# Owner(s): ["module: inductor"]
from torch.testing._internal.common_utils import run_tests
def get_static_linkage_main_cpp_file():
return """
#include <dlfcn.h>
#include <iostream>
#include <memory>
#include <torch/torch.h>
#include <vector>
#include <cuda.h>
#include <cuda_runtime_api.h>
// Include the AOTInductor headers
#include "Minus.wrapper/data/aotinductor/model/Minus.h"
#include "Plus.wrapper/data/aotinductor/model/Plus.h"
#include <torch/csrc/inductor/aoti_runtime/model_container.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
using torch::aot_inductor::AOTInductorModelMinus;
using torch::aot_inductor::AOTInductorModelPlus;
using torch::aot_inductor::ConstantHandle;
using torch::aot_inductor::ConstantMap;
int main(int argc, char* argv[]) {
if (argc < 2) {
std::cerr
<< "Usage: ./main <path> <device>"
<< std::endl;
return 1;
}
std::string path = argv[1];
std::string device_str = argv[2];
try {
torch::Device device(device_str);
// Create two input tensors (10x10)
auto tensor1 = torch::ones({10, 10}, device);
auto tensor2 = torch::ones({10, 10}, device);
// Create two input tensors (10x10)
auto tensor3 = torch::ones({10, 10}, device);
auto tensor4 = torch::ones({10, 10}, device);
std::vector<at::Tensor> input_tensors = {tensor1, tensor2};
std::vector<at::Tensor> input_tensors2 = {tensor3, tensor4};
// Create array of input handles
auto input_handles1 =
torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(
input_tensors);
auto input_handles2 =
torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(
input_tensors2);
// Create array for output handle
AtenTensorHandle output_handle1;
AtenTensorHandle output_handle2;
auto constants_map = std::make_shared<ConstantMap>();
auto constants_array = std::make_shared<std::vector<ConstantHandle>>();
auto model1 = AOTInductorModelPlus::Create(
constants_map, constants_array, device_str,
path + "Plus.wrapper/data/"
"aotinductor/model/");
model1->load_constants();
auto constants_map2 = std::make_shared<ConstantMap>();
auto constants_array2 = std::make_shared<std::vector<ConstantHandle>>();
auto model2 = AOTInductorModelMinus::Create(
constants_map2, constants_array2, device_str,
path + "Minus.wrapper/data/"
"aotinductor/model/");
model2->load_constants();
// Run the model
torch::aot_inductor::DeviceStreamType stream1 = nullptr;
torch::aot_inductor::DeviceStreamType stream2 = nullptr;
model1->run(&input_handles1[0], &output_handle1, stream1, nullptr);
model2->run(&input_handles2[0], &output_handle2, stream2, nullptr);
// Convert output handle to tensor
auto output_tensor1 =
torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
&output_handle1, 1);
auto output_tensor2 =
torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
&output_handle2, 1);
if (!(torch::all(output_tensor1[0] == 2).item<bool>())){
std::cout << "Wrong Output for Plus Model: " << output_tensor1 << std::endl;
throw std::runtime_error("Tensor does not contain only the expected value 2.");
}
if (!(torch::all(output_tensor2[0] == 0).item<bool>())){
std::cout << "Wrong Output for Minus Model: " << output_tensor1 << std::endl;
throw std::runtime_error("Tensor does not contain only the expected value 0.");
}
return 0;
} catch (const std::exception &e) {
std::cerr << "Error: " << e.what() << std::endl;
return 1;
}
}
"""
def get_static_linkage_makelist_file_cuda():
return """
cmake_minimum_required(VERSION 3.10)
project(TestProject)
set(CMAKE_CXX_STANDARD 17)
find_package(Torch REQUIRED)
find_package(CUDA REQUIRED)
add_subdirectory(Plus.wrapper/data/aotinductor/model/)
add_subdirectory(Minus.wrapper/data/aotinductor/model/)
# Create executable
add_executable(main main.cpp)
target_compile_definitions(main PRIVATE USE_CUDA)
target_link_libraries(main PRIVATE torch cuda
${CUDA_LIBRARIES}
Plus
Minus)
"""
def get_static_linkage_makelist_file_cpu():
return """
cmake_minimum_required(VERSION 3.10)
project(TestProject)
set(CMAKE_CXX_STANDARD 17)
find_package(Torch REQUIRED)
add_subdirectory(Plus.wrapper/data/aotinductor/model/)
add_subdirectory(Minus.wrapper/data/aotinductor/model/)
# Create executable
add_executable(main main.cpp)
target_link_libraries(main PRIVATE torch
Plus
Minus)
"""
if __name__ == "__main__":
run_tests()

View File

@ -1,14 +1,22 @@
import copy
import dataclasses
import functools
import os
import tempfile
import types
import typing
import typing_extensions
import zipfile
from pathlib import Path
import torch
from torch.export.experimental._utils import _get_main_cpp_file, _get_make_file
from torch.export.exported_program import _decompose_exported_program
__all__ = [] # type: ignore[var-annotated]
def _copy_graph_module_and_signature(
ep: torch.fx.GraphModule,
) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]:
@ -333,18 +341,79 @@ class _ExportPackage:
for overload, ep in method_data.overloads.items():
yield f"{method}:{overload}", ep
def _compiled_and_package(self, f: torch.types.FileLike) -> None:
options = {
def _compiled_and_package(
self,
f: torch.types.FileLike,
standalone: bool = False,
package_example_inputs: bool = False,
) -> None:
options: dict[str, typing.Any] = {
"aot_inductor.package": True,
"aot_inductor.package_cpp_only": True,
"always_keep_tensor_constants": True,
"aot_inductor.package_constants_in_so": False,
"aot_inductor.compile_standalone": standalone,
}
weights_map = {}
aoti_files_map = {}
model_names = []
for name, ep in self._method_overloads:
weights = torch._inductor.aot_compile(ep.module(), (), options=options) # type: ignore[arg-type]
weights_map[name] = weights
torch._inductor.package.package.package_aoti(
name = name.replace(":", "__")
model_names.append(name)
options["aot_inductor.model_name_for_generated_files"] = name
aoti_files = torch._inductor.aot_compile(
ep.module(), # type: ignore[arg-type]
ep.example_inputs[0],
kwargs=ep.example_inputs[1],
options=options,
)
aoti_files_map[name] = aoti_files
from torch._inductor.package import package
pt2_path = package.package_aoti(
f,
weights_map, # type: ignore[arg-type]
aoti_files_map, # type: ignore[arg-type]
)
if not standalone:
return
assert isinstance(pt2_path, str)
base_directory = os.path.dirname(pt2_path)
package_name = os.path.basename(pt2_path)[:-4]
with (
zipfile.ZipFile(pt2_path, "r") as zip_ref,
):
zip_ref.extractall(base_directory)
example_inputs_map: typing.Optional[dict[str, int]] = (
{} if package_example_inputs else None
)
use_cuda = False
for name, ep in self._method_overloads:
name = name.replace(":", "__")
# TODO: also dump kwargs
# TODO: currently only support list of Tensors and they need to be on the same device
if not ep.example_inputs:
continue
for inp in ep.example_inputs[0]:
if isinstance(inp, torch.Tensor) and inp.device.type == "cuda":
# TODO: more carefully determine the device type
use_cuda = True
if package_example_inputs:
assert example_inputs_map is not None
example_inputs_map[name] = len(ep.example_inputs[0])
for i, t in enumerate(ep.example_inputs[0]):
path = Path(base_directory) / f"{name}_input_{i}.pt"
torch.save(t, path)
cmake_file_str = _get_make_file(package_name, model_names, use_cuda)
with open(Path(base_directory) / "CMakeLists.txt", "w") as file:
file.write(cmake_file_str)
main_file_str = _get_main_cpp_file(
package_name, model_names, use_cuda, example_inputs_map
)
with open(Path(base_directory) / "main.cpp", "w") as file:
file.write(main_file_str)

View File

@ -0,0 +1,206 @@
import typing
from torch._inductor.utils import IndentedBuffer
__all__ = [] # type: ignore[var-annotated]
def _get_main_cpp_file(
package_name: str,
model_names: list[str],
cuda: bool,
example_inputs_map: typing.Optional[dict[str, int]],
) -> str:
"""
Generates a main.cpp file for AOTInductor standalone models in the specified package.
Args:
package_name (str): Name of the package containing the models.
model_names (List[str]): List of model names to include in the generated main.cpp.
cuda (bool): Whether to generate code with CUDA support.
example_inputs_map (Optional[Dict[str, List[Tensor]]]): A mapping from model name to
its list of example input tensors. If provided, the generated main.cpp will
load and run these inputs.
Returns:
str: The contents of the generated main.cpp file as a string.
"""
ib = IndentedBuffer()
ib.writelines(
[
"#include <dlfcn.h>",
"#include <fstream>",
"#include <iostream>",
"#include <memory>",
"#include <torch/torch.h>",
"#include <vector>",
"#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>",
]
)
if cuda:
ib.writelines(
[
"#include <cuda.h>",
"#include <cuda_runtime_api.h>",
]
)
for model_name in model_names:
ib.writeline(
f'#include "{package_name}/data/aotinductor/{model_name}/{model_name}.h"'
)
ib.newline()
for model_name in model_names:
ib.writeline(f"using torch::aot_inductor::AOTInductorModel{model_name};")
ib.writelines(
[
"using torch::aot_inductor::ConstantHandle;",
"using torch::aot_inductor::ConstantMap;",
"",
"int main(int argc, char* argv[]) {",
]
)
with ib.indent():
ib.writeline(f'std::string device_str = "{"cuda" if cuda else "cpu"}";')
ib.writeline("try {")
with ib.indent():
ib.writeline("c10::Device device(device_str);")
if example_inputs_map is not None:
# TODO: add device
for i, model_name in enumerate(model_names):
num_inputs = example_inputs_map[model_name]
ib.writeline(f"// Load input tensors for model {model_name}")
ib.writeline(f"std::vector<at::Tensor> input_tensors{i + 1};")
ib.writeline(f"for (int j = 0; j < {num_inputs}; ++j) {{")
with ib.indent():
ib.writeline(
f'std::string filename = "{model_name}_input_" + std::to_string(j) + ".pt";'
)
ib.writeline("std::ifstream in(filename, std::ios::binary);")
ib.writeline("if (!in.is_open()) {")
with ib.indent():
ib.writeline(
'std::cerr << "Failed to open file: " << filename << std::endl;'
)
ib.writeline("return 1;")
ib.writeline("}")
ib.writeline(
"std::vector<char> buffer((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());"
)
ib.writeline(
"torch::IValue ivalue = torch::pickle_load(buffer);"
)
ib.writeline(
f"input_tensors{i + 1}.push_back(ivalue.toTensor().to(device));"
)
ib.writeline("}")
ib.newline()
ib.newline()
ib.writeline("\n// Create array of input handles")
for i in range(len(model_names)):
ib.writelines(
[
f"auto input_handles{i + 1} =",
f" torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors{i + 1});",
]
)
ib.writeline("\n// Create array for output handles")
for i in range(len(model_names)):
ib.writeline(f"AtenTensorHandle output_handle{i + 1};")
ib.writeline("\n// Create and load models")
for i, model_name in enumerate(model_names):
ib.writelines(
[
f"auto constants_map{i + 1} = std::make_shared<ConstantMap>();",
f"auto constants_array{i + 1} = std::make_shared<std::vector<ConstantHandle>>();",
f"auto model{i + 1} = AOTInductorModel{model_name}::Create(",
f" constants_map{i + 1}, constants_array{i + 1}, device_str,",
f' "{package_name}/data/aotinductor/{model_name}/");',
f"model{i + 1}->load_constants();",
]
)
if example_inputs_map is not None:
ib.writeline("\n// Run the models")
for i in range(len(model_names)):
ib.writeline(
f"torch::aot_inductor::DeviceStreamType stream{i + 1} = nullptr;"
)
ib.writeline(
f"model{i + 1}->run(&input_handles{i + 1}[0], &output_handle{i + 1}, stream{i + 1}, nullptr);"
)
ib.writeline("\n// Convert output handles to tensors")
for i in range(len(model_names)):
ib.writelines(
[
f"auto output_tensor{i + 1} =",
f" torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle{i + 1}, 1);",
]
)
ib.writeline("\n// Validate outputs")
for i in range(len(model_names)):
ib.writeline(
f"""std::cout << "output_tensor{i + 1}" << output_tensor{i + 1} << std::endl;"""
)
ib.writeline("return 0;")
ib.writelines(
[
"} catch (const std::exception &e) {",
]
)
with ib.indent():
ib.writeline('std::cerr << "Error: " << e.what() << std::endl;')
ib.writeline("return 1;")
ib.writeline("}")
ib.writeline("}")
return ib.getvalue()
def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str:
ib = IndentedBuffer()
ib.writelines(
[
"cmake_minimum_required(VERSION 3.10)",
"project(TestProject)",
"",
"set(CMAKE_CXX_STANDARD 17)",
"",
"find_package(Torch REQUIRED)",
]
)
if cuda:
ib.writeline("find_package(CUDA REQUIRED)")
ib.newline()
for model_name in model_names:
ib.writeline(f"add_subdirectory({package_name}/data/aotinductor/{model_name}/)")
ib.writeline("\nadd_executable(main main.cpp)")
if cuda:
ib.writeline("target_compile_definitions(main PRIVATE USE_CUDA)")
model_libs = " ".join(model_names)
ib.writeline(f"target_link_libraries(main PRIVATE torch {model_libs})")
if cuda:
ib.writeline("target_link_libraries(main PRIVATE cuda ${CUDA_LIBRARIES})")
return ib.getvalue()