mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Cleaning up ignores that are no longer needed in the repo and adding select suppressions so the main branch is clean. test plan: `lintrunner -a` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166178 Approved by: https://github.com/oulgen
139 lines
4.3 KiB
Python
139 lines
4.3 KiB
Python
import io
|
|
import json
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
from typing import IO
|
|
|
|
import torch
|
|
from torch._inductor import config
|
|
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
|
|
from torch.export.pt2_archive._package import (
|
|
AOTI_FILES,
|
|
AOTICompiledModel,
|
|
load_pt2,
|
|
package_pt2,
|
|
)
|
|
from torch.types import FileLike
|
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
|
|
|
|
|
|
def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
    """
    Compile and link AOTInductor-generated sources into a shared library.

    Steps:
      1. Locate the generated ``.cpp`` source and the constants ``.o`` file in
         ``aoti_files`` (entries are joined onto ``aoti_dir``).
      2. Compile the ``.cpp`` into an object file using the flags stored in
         ``<stem>_compile_flags.json`` next to it.
      3. Link that object together with the constants object into a shared
         library using the flags stored in ``<stem>_linker_flags.json``.
      4. If ``<stem>_serialized_weights.bin`` is listed in ``aoti_files``,
         append its bytes to the shared library, padded so that the weights
         start on a 16 KiB boundary.

    Args:
        aoti_dir: Directory that the entries of ``aoti_files`` are joined onto.
        aoti_files: Names of the AOTInductor-generated files.
        so_path: Passed to the linker builder as its output directory, and its
            basename is used as the library name.

    Returns:
        Path of the shared library reported by the linker builder.

    Raises:
        RuntimeError: If ``aoti_files`` contains no ``.cpp`` or no ``.o`` file.
    """

    def get_aoti_file_with_suffix(suffix: str) -> str:
        # First listed file ending in ``suffix``; fail loudly if there is none.
        for file in aoti_files:
            if file.endswith(suffix):
                return file
        raise RuntimeError(f"Unable to find file with suffix {suffix}")

    # Compile all the files into a .so
    cpp_name = get_aoti_file_with_suffix(".cpp")
    consts_name = get_aoti_file_with_suffix(".o")
    cpp_file = os.path.join(aoti_dir, cpp_name)
    consts_o = os.path.join(aoti_dir, consts_name)

    # Path of the .cpp without its extension; the flag/weights side files
    # share this stem.
    file_name = os.path.splitext(cpp_file)[0]

    # Parse compile flags and build the .o file
    with open(file_name + "_compile_flags.json", encoding="utf-8") as f:
        compile_flags = json.load(f)

    compile_options = BuildOptionsBase(
        **compile_flags, use_relative_path=config.is_fbcode()
    )
    object_builder = CppBuilder(
        name=file_name,
        sources=cpp_file,
        BuildOption=compile_options,
    )
    output_o = object_builder.get_target_file_path()
    object_builder.build()

    # Parse linker flags and build the .so file
    with open(file_name + "_linker_flags.json", encoding="utf-8") as f:
        linker_flags = json.load(f)

    linker_options = BuildOptionsBase(
        **linker_flags, use_relative_path=config.is_fbcode()
    )
    so_builder = CppBuilder(
        name=os.path.split(so_path)[-1],
        sources=[output_o, consts_o],
        BuildOption=linker_options,
        output_dir=so_path,
    )
    output_so = so_builder.get_target_file_path()
    so_builder.build()

    # mmapped weights
    # ``aoti_files`` entries are resolved relative to ``aoti_dir`` (see
    # cpp_file above), so look the weights file up by its listed name. The
    # joined path is still accepted for callers that list full paths.
    weights_name = os.path.splitext(cpp_name)[0] + "_serialized_weights.bin"
    serialized_weights_filename = file_name + "_serialized_weights.bin"
    if weights_name in aoti_files or serialized_weights_filename in aoti_files:
        with open(serialized_weights_filename, "rb") as f_weights:
            serialized_weights = f_weights.read()

        page_size = 16384
        with open(output_so, "a+b") as f_so:
            so_size = f_so.tell()
            # Page align the weights: pad up to the next 16 KiB boundary,
            # writing nothing when the library already ends on one.
            f_so.write(b" " * (-so_size % page_size))
            f_so.write(serialized_weights)

    return output_so
|
|
|
|
|
|
def package_aoti(
    archive_file: FileLike,
    aoti_files: AOTI_FILES,
) -> FileLike:
    """
    Write AOTInductor-generated artifacts into the PT2Archive format.

    This is a thin wrapper: ``archive_file`` is forwarded positionally and
    ``aoti_files`` by keyword to ``package_pt2``, whose result is returned.

    Args:
        archive_file: Where the package is saved (a file name, or anything
            else accepted as ``FileLike``).
        aoti_files: The AOTInductor outputs to package, described either as a
            single location of the generated files or as a dictionary mapping
            each model name to its generated files.
            NOTE(review): the exact accepted shapes are defined by the
            ``AOTI_FILES`` alias — confirm there.

    Returns:
        Whatever ``package_pt2`` returns for this archive.
    """
    packaged = package_pt2(archive_file, aoti_files=aoti_files)
    return packaged
|
|
|
|
|
|
def load_package(
    path: FileLike,
    model_name: str = "model",
    run_single_threaded: bool = False,
    num_runners: int = 1,
    device_index: int = -1,
) -> AOTICompiledModel:
    """
    Load an AOTInductor-compiled model from a PT2 package.

    The package is first read with ``load_pt2``. If that raises
    ``RuntimeError``, a warning is logged and the package is loaded again
    through the legacy ``torch._C._aoti.AOTIModelPackageLoader``. This
    includes the case where ``model_name`` is not among the loaded runners.

    Args:
        path: Package location, or a stream object (``io.IOBase`` /
            ``typing.IO``) holding the package bytes.
        model_name: Key of the model to load from the package.
        run_single_threaded: Forwarded unchanged to both loaders.
        num_runners: Forwarded unchanged to both loaders.
        device_index: Forwarded unchanged to both loaders.

    Returns:
        The ``AOTICompiledModel`` for ``model_name``.

    Raises:
        Whatever the legacy loader raises if the fallback path also fails.
    """
    try:
        pt2_contents = load_pt2(
            path,
            run_single_threaded=run_single_threaded,
            num_runners=num_runners,
            device_index=device_index,
        )
        if model_name not in pt2_contents.aoti_runners:
            raise RuntimeError(f"Model {model_name} not found in package")
        return pt2_contents.aoti_runners[model_name]
    except RuntimeError:
        log.warning("Loading outdated pt2 file. Please regenerate your package.")

    def _load_with_legacy_loader(file_path) -> AOTICompiledModel:
        # Build the legacy C++ loader for a package stored at ``file_path``.
        loader = torch._C._aoti.AOTIModelPackageLoader(
            file_path, model_name, run_single_threaded, num_runners, device_index
        )
        return AOTICompiledModel(loader)

    if isinstance(path, (io.IOBase, IO)):
        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
            # TODO(angelayi): We shouldn't need to do this -- miniz should
            # handle reading the buffer. This is just a temporary workaround
            path.seek(0)
            f.write(path.read())
            # The C++ loader reopens the file by name, so push any bytes still
            # held in Python's write buffer to disk first.
            f.flush()
            log.debug("Writing buffer to tmp file located at %s.", f.name)
            # NOTE(review): the temp file is deleted when this ``with`` exits;
            # this assumes the loader has finished reading it by then.
            return _load_with_legacy_loader(f.name)

    # AOTIModelPackageLoader expects (str, str)
    return _load_with_legacy_loader(os.fspath(path))
|