pytorch/torch/_inductor/package/package.py
Maggie Moss eb83c3ca23 Clean up unused Pyrefly suppressions (#166178)
Cleaning up ignores that are no longer needed in the repo and adding select suppressions so the main branch is clean.

test plan:
`lintrunner -a`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166178
Approved by: https://github.com/oulgen
2025-10-25 05:32:21 +00:00

139 lines
4.3 KiB
Python

import io
import json
import logging
import os
import tempfile
from typing import IO
import torch
from torch._inductor import config
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
from torch.export.pt2_archive._package import (
AOTI_FILES,
AOTICompiledModel,
load_pt2,
package_pt2,
)
from torch.types import FileLike
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
    """
    Build a shared library from AOTInductor-generated files.

    The ``.cpp`` entry is compiled into an object file using the flags stored
    in ``<stem>_compile_flags.json``; that object is then linked together with
    the ``.o`` entry using the flags stored in ``<stem>_linker_flags.json``.
    ``<stem>`` is the ``.cpp`` path (joined onto ``aoti_dir``) without its
    extension. If ``<stem>_serialized_weights.bin`` appears in ``aoti_files``,
    its bytes are appended to the built library after padding.

    Args:
        aoti_dir: Directory the selected ``aoti_files`` entries are joined onto.
        aoti_files: Candidate file names; the first one ending in ``.cpp`` and
            the first one ending in ``.o`` are used.
        so_path: Its last path component is used as the linker builder's name,
            and the full value is passed as that builder's ``output_dir``.

    Returns:
        The target file path reported by the linking ``CppBuilder``.

    Raises:
        RuntimeError: If no entry of ``aoti_files`` ends with ``.cpp`` or ``.o``.
    """

    def first_with_suffix(suffix: str) -> str:
        # List order decides which entry is picked when several match.
        found = next((name for name in aoti_files if name.endswith(suffix)), None)
        if found is None:
            raise RuntimeError(f"Unable to find file with suffix {suffix}")
        return found

    def options_from_json(flags_path: str):
        # The JSON object's keys are forwarded verbatim as keyword arguments
        # to BuildOptionsBase.
        with open(flags_path) as flags_file:
            flags = json.load(flags_file)
        return BuildOptionsBase(**flags, use_relative_path=config.is_fbcode())

    # Both lookups happen up front, so a missing input fails before any build.
    source_path = os.path.join(aoti_dir, first_with_suffix(".cpp"))
    constants_obj = os.path.join(aoti_dir, first_with_suffix(".o"))
    stem, _ = os.path.splitext(source_path)

    # Step 1: compile the generated C++ source into an object file.
    compiler = CppBuilder(
        name=stem,
        sources=source_path,
        BuildOption=options_from_json(f"{stem}_compile_flags.json"),
    )
    compiled_obj = compiler.get_target_file_path()
    compiler.build()

    # Step 2: link the fresh object with the constants object into the .so.
    linker = CppBuilder(
        name=os.path.basename(so_path),
        sources=[compiled_obj, constants_obj],
        BuildOption=options_from_json(f"{stem}_linker_flags.json"),
        output_dir=so_path,
    )
    library_path = linker.get_target_file_path()
    linker.build()

    # Step 3 (optional): append serialized weights to the library file.
    # NOTE(review): this compares the directory-joined path against the raw
    # entries of aoti_files, so it only matches when the entries are spelled
    # as that same joined path -- confirm against callers.
    weights_path = f"{stem}_serialized_weights.bin"
    if weights_path not in aoti_files:
        return library_path

    with open(weights_path, "rb") as weights_file:
        weights_blob = weights_file.read()

    with open(library_path, "a+b") as library_file:
        alignment = 16384
        current_size = library_file.tell()
        # Pad with spaces up to the next 16384-byte boundary so the weights
        # start page-aligned. A size that is already aligned still receives
        # one full block of padding.
        library_file.write(b" " * (alignment - current_size % alignment))
        library_file.write(weights_blob)

    return library_path
def package_aoti(
    archive_file: FileLike,
    aoti_files: AOTI_FILES,
) -> FileLike:
    """
    Write AOTInductor-generated files into a PT2Archive package.

    This is a thin wrapper: both arguments are forwarded to ``package_pt2``
    and its result is returned unchanged.

    Args:
        archive_file: The file name to save the package to.
        aoti_files: Either a single path to a directory holding the
            AOTInductor files, or a dictionary that maps each model name to
            the path of its AOTInductor-generated files.

    Returns:
        The value returned by ``package_pt2`` for this archive.
    """
    packaged = package_pt2(archive_file, aoti_files=aoti_files)
    return packaged
def load_package(
    path: FileLike,
    model_name: str = "model",
    run_single_threaded: bool = False,
    num_runners: int = 1,
    device_index: int = -1,
) -> AOTICompiledModel:
    """
    Load one compiled model out of a pt2 package.

    The package is first opened with ``load_pt2``. If that raises
    ``RuntimeError`` -- including the error raised here when ``model_name`` is
    not among the loaded runners -- a warning is logged and the package is
    opened directly with ``torch._C._aoti.AOTIModelPackageLoader`` instead.

    Args:
        path: Package location: either a path-like value or an open file
            object. File objects are copied to a temporary ``.pt2`` file on
            the fallback path, because the loader takes a file name.
        model_name: Key of the model to return from the package.
        run_single_threaded: Forwarded unchanged to the underlying loader.
        num_runners: Forwarded unchanged to the underlying loader.
        device_index: Forwarded unchanged to the underlying loader.

    Returns:
        The runner stored under ``model_name``, or an ``AOTICompiledModel``
        wrapping the fallback loader.
    """
    try:
        pt2_contents = load_pt2(
            path,
            run_single_threaded=run_single_threaded,
            num_runners=num_runners,
            device_index=device_index,
        )
        if model_name not in pt2_contents.aoti_runners:
            # Caught by the handler below, so a missing model name also
            # triggers the fallback loader rather than propagating directly.
            raise RuntimeError(f"Model {model_name} not found in package")
        return pt2_contents.aoti_runners[model_name]
    except RuntimeError:
        log.warning("Loading outdated pt2 file. Please regenerate your package.")

    if isinstance(path, (io.IOBase, IO)):
        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
            # TODO(angelayi): We shouldn't need to do this -- miniz should
            # handle reading the buffer. This is just a temporary workaround
            path.seek(0)
            f.write(path.read())
            # The loader reopens the file by name, so everything still held
            # in Python's write buffer must reach the file first; otherwise
            # the loader can see a truncated (or empty) archive.
            f.flush()
            log.debug("Writing buffer to tmp file located at %s.", f.name)
            loader = torch._C._aoti.AOTIModelPackageLoader(
                f.name, model_name, run_single_threaded, num_runners, device_index
            )
            # Returned while the temporary file still exists; it is removed
            # when the ``with`` block exits.
            return AOTICompiledModel(loader)

    path = os.fspath(path)  # AOTIModelPackageLoader expects (str, str)
    loader = torch._C._aoti.AOTIModelPackageLoader(
        path, model_name, run_single_threaded, num_runners, device_index
    )
    return AOTICompiledModel(loader)