pytorch/torch/export/experimental/_utils.py
Shangdi Yu cf3247b74a Standalone compile API in _Exporter (#158139)
Given a `package: _ExportPackage`, users can get a ready-to-use workspace in `tmp_dir` by calling:
```python
package._compiled_and_package(
    tmp_dir + "/pt2_package_name.pt2", True, package_example_inputs=True
)
```

`tmp_dir` will contain:
- `main.cpp` (an example C++ file that creates the models; if `package_example_inputs` is True, it also loads the example inputs and runs the models)
- `CMakeLists.txt`
- `pt2_package_name/` (the directory containing the models)
- `pt2_package_name.pt2`
- `<model_name>_input_<i>.pt` files (one per example input) if `package_example_inputs` is True

Remaining TODOs:
- support loading constants/weights
- the `package_example_inputs=True` option only supports a list of Tensors for now
- eventually we should remove the `torch` dependency and use `SlimTensor`/`StableIValue` instead.

Test Plan:
```
python test/inductor/test_aot_inductor_package.py  -k test_compile_with_exporter
```

Example generated `main.cpp`:

```cpp
#include <dlfcn.h>
#include <fstream>
#include <iostream>
#include <memory>
#include <torch/torch.h>
#include <vector>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include "package/data/aotinductor/Plus__default/Plus__default.h"
#include "package/data/aotinductor/Minus__default/Minus__default.h"

using torch::aot_inductor::AOTInductorModelPlus__default;
using torch::aot_inductor::AOTInductorModelMinus__default;
using torch::aot_inductor::ConstantHandle;
using torch::aot_inductor::ConstantMap;

int main(int argc, char* argv[]) {
    std::string device_str = "cpu";
    try {
        c10::Device device(device_str);
        // Load input tensors for model Plus__default
        std::vector<at::Tensor> input_tensors1;
        for (int j = 0; j < 2; ++j) {
            std::string filename = "Plus__default_input_" + std::to_string(j) + ".pt";
            std::ifstream in(filename, std::ios::binary);
            if (!in.is_open()) {
                std::cerr << "Failed to open file: " << filename << std::endl;
                return 1;
            }
            std::vector<char> buffer((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
            torch::IValue ivalue = torch::pickle_load(buffer);
            input_tensors1.push_back(ivalue.toTensor().to(device));
        }

        // Load input tensors for model Minus__default
        std::vector<at::Tensor> input_tensors2;
        for (int j = 0; j < 2; ++j) {
            std::string filename = "Minus__default_input_" + std::to_string(j) + ".pt";
            std::ifstream in(filename, std::ios::binary);
            if (!in.is_open()) {
                std::cerr << "Failed to open file: " << filename << std::endl;
                return 1;
            }
            std::vector<char> buffer((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
            torch::IValue ivalue = torch::pickle_load(buffer);
            input_tensors2.push_back(ivalue.toTensor().to(device));
        }

// Create array of input handles
        auto input_handles1 =
            torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors1);
        auto input_handles2 =
            torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors2);

// Create array for output handles
        AtenTensorHandle output_handle1;
        AtenTensorHandle output_handle2;

// Create and load models
        auto constants_map1 = std::make_shared<ConstantMap>();
        auto constants_array1 = std::make_shared<std::vector<ConstantHandle>>();
        auto model1 = AOTInductorModelPlus__default::Create(
            constants_map1, constants_array1, device_str,
            "package/data/aotinductor/Plus__default/");
        model1->load_constants();
        auto constants_map2 = std::make_shared<ConstantMap>();
        auto constants_array2 = std::make_shared<std::vector<ConstantHandle>>();
        auto model2 = AOTInductorModelMinus__default::Create(
            constants_map2, constants_array2, device_str,
            "package/data/aotinductor/Minus__default/");
        model2->load_constants();

// Run the models
        torch::aot_inductor::DeviceStreamType stream1 = nullptr;
        model1->run(&input_handles1[0], &output_handle1, stream1, nullptr);
        torch::aot_inductor::DeviceStreamType stream2 = nullptr;
        model2->run(&input_handles2[0], &output_handle2, stream2, nullptr);

// Convert output handles to tensors
        auto output_tensor1 =
            torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle1, 1);
        auto output_tensor2 =
            torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle2, 1);

// Validate outputs
        std::cout << "output_tensor1" << output_tensor1 << std::endl;
        std::cout << "output_tensor2" << output_tensor2 << std::endl;
        return 0;
    } catch (const std::exception &e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }
}

```

Rollback Plan:

Differential Revision: D78124705

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158139
Approved by: https://github.com/desertfire
2025-07-15 18:47:56 +00:00

207 lines
7.5 KiB
Python

import typing
from torch._inductor.utils import IndentedBuffer
# An empty ``__all__`` means ``from ... import *`` exports nothing from this
# module; the code-generation helpers below are private (underscore-prefixed).
__all__ = [] # type: ignore[var-annotated]
def _get_main_cpp_file(
package_name: str,
model_names: list[str],
cuda: bool,
example_inputs_map: typing.Optional[dict[str, int]],
) -> str:
"""
Generates a main.cpp file for AOTInductor standalone models in the specified package.
Args:
package_name (str): Name of the package containing the models.
model_names (List[str]): List of model names to include in the generated main.cpp.
cuda (bool): Whether to generate code with CUDA support.
example_inputs_map (Optional[Dict[str, List[Tensor]]]): A mapping from model name to
its list of example input tensors. If provided, the generated main.cpp will
load and run these inputs.
Returns:
str: The contents of the generated main.cpp file as a string.
"""
ib = IndentedBuffer()
ib.writelines(
[
"#include <dlfcn.h>",
"#include <fstream>",
"#include <iostream>",
"#include <memory>",
"#include <torch/torch.h>",
"#include <vector>",
"#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>",
]
)
if cuda:
ib.writelines(
[
"#include <cuda.h>",
"#include <cuda_runtime_api.h>",
]
)
for model_name in model_names:
ib.writeline(
f'#include "{package_name}/data/aotinductor/{model_name}/{model_name}.h"'
)
ib.newline()
for model_name in model_names:
ib.writeline(f"using torch::aot_inductor::AOTInductorModel{model_name};")
ib.writelines(
[
"using torch::aot_inductor::ConstantHandle;",
"using torch::aot_inductor::ConstantMap;",
"",
"int main(int argc, char* argv[]) {",
]
)
with ib.indent():
ib.writeline(f'std::string device_str = "{"cuda" if cuda else "cpu"}";')
ib.writeline("try {")
with ib.indent():
ib.writeline("c10::Device device(device_str);")
if example_inputs_map is not None:
# TODO: add device
for i, model_name in enumerate(model_names):
num_inputs = example_inputs_map[model_name]
ib.writeline(f"// Load input tensors for model {model_name}")
ib.writeline(f"std::vector<at::Tensor> input_tensors{i + 1};")
ib.writeline(f"for (int j = 0; j < {num_inputs}; ++j) {{")
with ib.indent():
ib.writeline(
f'std::string filename = "{model_name}_input_" + std::to_string(j) + ".pt";'
)
ib.writeline("std::ifstream in(filename, std::ios::binary);")
ib.writeline("if (!in.is_open()) {")
with ib.indent():
ib.writeline(
'std::cerr << "Failed to open file: " << filename << std::endl;'
)
ib.writeline("return 1;")
ib.writeline("}")
ib.writeline(
"std::vector<char> buffer((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());"
)
ib.writeline(
"torch::IValue ivalue = torch::pickle_load(buffer);"
)
ib.writeline(
f"input_tensors{i + 1}.push_back(ivalue.toTensor().to(device));"
)
ib.writeline("}")
ib.newline()
ib.newline()
ib.writeline("\n// Create array of input handles")
for i in range(len(model_names)):
ib.writelines(
[
f"auto input_handles{i + 1} =",
f" torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors{i + 1});",
]
)
ib.writeline("\n// Create array for output handles")
for i in range(len(model_names)):
ib.writeline(f"AtenTensorHandle output_handle{i + 1};")
ib.writeline("\n// Create and load models")
for i, model_name in enumerate(model_names):
ib.writelines(
[
f"auto constants_map{i + 1} = std::make_shared<ConstantMap>();",
f"auto constants_array{i + 1} = std::make_shared<std::vector<ConstantHandle>>();",
f"auto model{i + 1} = AOTInductorModel{model_name}::Create(",
f" constants_map{i + 1}, constants_array{i + 1}, device_str,",
f' "{package_name}/data/aotinductor/{model_name}/");',
f"model{i + 1}->load_constants();",
]
)
if example_inputs_map is not None:
ib.writeline("\n// Run the models")
for i in range(len(model_names)):
ib.writeline(
f"torch::aot_inductor::DeviceStreamType stream{i + 1} = nullptr;"
)
ib.writeline(
f"model{i + 1}->run(&input_handles{i + 1}[0], &output_handle{i + 1}, stream{i + 1}, nullptr);"
)
ib.writeline("\n// Convert output handles to tensors")
for i in range(len(model_names)):
ib.writelines(
[
f"auto output_tensor{i + 1} =",
f" torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle{i + 1}, 1);",
]
)
ib.writeline("\n// Validate outputs")
for i in range(len(model_names)):
ib.writeline(
f"""std::cout << "output_tensor{i + 1}" << output_tensor{i + 1} << std::endl;"""
)
ib.writeline("return 0;")
ib.writelines(
[
"} catch (const std::exception &e) {",
]
)
with ib.indent():
ib.writeline('std::cerr << "Error: " << e.what() << std::endl;')
ib.writeline("return 1;")
ib.writeline("}")
ib.writeline("}")
return ib.getvalue()
def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str:
ib = IndentedBuffer()
ib.writelines(
[
"cmake_minimum_required(VERSION 3.10)",
"project(TestProject)",
"",
"set(CMAKE_CXX_STANDARD 17)",
"",
"find_package(Torch REQUIRED)",
]
)
if cuda:
ib.writeline("find_package(CUDA REQUIRED)")
ib.newline()
for model_name in model_names:
ib.writeline(f"add_subdirectory({package_name}/data/aotinductor/{model_name}/)")
ib.writeline("\nadd_executable(main main.cpp)")
if cuda:
ib.writeline("target_compile_definitions(main PRIVATE USE_CUDA)")
model_libs = " ".join(model_names)
ib.writeline(f"target_link_libraries(main PRIVATE torch {model_libs})")
if cuda:
ib.writeline("target_link_libraries(main PRIVATE cuda ${CUDA_LIBRARIES})")
return ib.getvalue()