Revert "[export] Move PT2ArchiveWriter/Reader to torch/export (#153795)"

This reverts commit 7e80f23516.

Reverted https://github.com/pytorch/pytorch/pull/153795 on behalf of https://github.com/malfet due to Looks like it broke lots of tests, see ec368a1903/1 ([comment](https://github.com/pytorch/pytorch/pull/153795#issuecomment-2905415496))
This commit is contained in:
PyTorch MergeBot 2025-05-23 18:29:08 +00:00
parent ec368a1903
commit 4ff19ecf66
6 changed files with 85 additions and 231 deletions

View File

@ -2650,11 +2650,6 @@
"torch.export.graph_signature": [
"TokenArgument"
],
"torch.export.pt2_archive": [
"PT2ArchiveWriter",
"PT2ArchiveReader",
"is_pt2_package"
],
"torch.fx.experimental.shape_inference.infer_shape": [
"DimDynamic",
"FakeTensorMode",

View File

@ -216,10 +216,8 @@ class TestAOTInductorPackage(TestCase):
with tempfile.TemporaryDirectory() as tmp_dir, zipfile.ZipFile(
package_path, "r"
) as zip_ref:
filenames = zip_ref.namelist()
prefix = filenames[0].split("/")[0]
zip_ref.extractall(tmp_dir)
tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model"
tmp_path = Path(tmp_dir) / "data" / "aotinductor" / "model"
self.assertTrue(tmp_path.exists())
if self.device == GPU_TYPE:
kernel_bin = get_kernel_bin_format(self.device)

View File

@ -3,7 +3,10 @@ import json
import logging
import os
import tempfile
import zipfile
from pathlib import Path
from typing import Any, IO, Optional, Union
from typing_extensions import Self
import torch
import torch._inductor
@ -11,9 +14,9 @@ import torch.utils._pytree as pytree
from torch._inductor import config
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
from torch.export._tree_utils import reorder_kwargs
from torch.export.pt2_archive._package import PT2ArchiveWriter
from torch.export.pt2_archive.constants import (
AOTINDUCTOR_DIR,
ARCHIVE_VERSION_VALUE,
CONSTANTS_DIR,
CUSTOM_OBJ_FILENAME_PREFIX,
)
@ -23,6 +26,74 @@ from torch.types import FileLike
log = logging.getLogger(__name__)
class PT2ArchiveWriter:
    """
    Context manager that writes a PT2 archive as an uncompressed zip file.

    Entering the context opens ``archive_path`` for writing and stores the
    ``version`` and ``archive_format`` entries; leaving the context closes
    the archive.
    """

    def __init__(self, archive_path: FileLike) -> None:
        self.archive_path: FileLike = archive_path
        # Open zip handle; only set while the context is active.
        self.archive_file: Optional[zipfile.ZipFile] = None

    def __enter__(self) -> Self:
        assert self.archive_file is None
        archive = zipfile.ZipFile(
            self.archive_path, "w", compression=zipfile.ZIP_STORED
        )
        self.archive_file = archive
        try:
            self.writestr("version", str(ARCHIVE_VERSION_VALUE))
            self.writestr("archive_format", "pt2")
        except BaseException:
            # Do not leak the open zip handle (or leave the writer looking
            # "entered") when the header entries cannot be written.
            try:
                archive.close()
            finally:
                self.archive_file = None
            raise
        return self

    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
        assert self.archive_file is not None
        try:
            self.archive_file.close()
        finally:
            # Reset even if close() raises so the writer never keeps a
            # reference to a half-closed archive.
            self.archive_file = None
        return None

    def writestr(self, name: str, data: Union[bytes, str]) -> None:
        """
        Write in-memory data into the archive.
        name: The destination file inside the archive.
        data: The bytes or string content to store.
        """
        assert self.archive_file is not None
        self.archive_file.writestr(name, data)

    def write_file(self, name: str, file_path: str) -> None:
        """
        Copy a file into the archive.
        name: The destination file inside the archive.
        file_path: The source file on disk.
        """
        assert Path(file_path).is_file(), f"{file_path} is not a valid file path"
        assert self.archive_file is not None
        self.archive_file.write(file_path, arcname=name)
class PT2ArchiveReader:
    """
    Context manager that reads members of a PT2 archive stored as a zip file.

    The archive is opened on ``__enter__`` and closed on ``__exit__``; every
    accessor asserts that the archive is currently open.
    """

    def __init__(self, archive_path: str) -> None:
        self.archive_path: str = archive_path
        # Open zip handle; only set while the context is active.
        self.archive_file: Optional[zipfile.ZipFile] = None

    def __enter__(self) -> Self:
        self.archive_file = zipfile.ZipFile(
            self.archive_path, "r", compression=zipfile.ZIP_STORED
        )
        return self

    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
        if self.archive_file is not None:
            try:
                self.archive_file.close()
            finally:
                # Drop the closed handle so later calls fail the
                # "archive is open" assertion instead of operating on a
                # closed ZipFile.
                self.archive_file = None
        return None

    def read(self, name: str) -> bytes:
        """Return the raw bytes of the archive member ``name``."""
        assert self.archive_file is not None
        return self.archive_file.read(name)

    def extract_to_path(self, member: str, path: str) -> str:
        """Extract ``member`` under directory ``path`` and return the created path."""
        assert self.archive_file is not None
        return self.archive_file.extract(member, path)

    def extractall(self, path: str) -> None:
        """Extract every archive member under directory ``path``."""
        assert self.archive_file is not None
        self.archive_file.extractall(path)

    def get_file_names(self) -> list[str]:
        """Return the names of all members in the archive."""
        assert self.archive_file is not None
        return self.archive_file.namelist()
def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
def get_aoti_file_with_suffix(suffix: str) -> str:
for file in aoti_files:

View File

@ -367,7 +367,15 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive)));
}
std::vector<std::string> found_filenames;
temp_dir_ = create_temp_dir();
std::string so_filename;
std::string cpp_filename;
std::vector<std::string> obj_filenames;
std::string found_filenames; // Saving for bookkeeping
std::string model_directory =
"data" + k_separator + "aotinductor" + k_separator + model_name;
std::string const_directory = "data" + k_separator + "constants";
for (uint32_t i = 0; i < zip_archive.m_total_files; i++) {
uint32_t filename_len =
mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0);
@ -381,40 +389,10 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
&zip_archive, i, filename_str.data(), filename_len)) {
throw std::runtime_error("Failed to read filename");
}
found_filenames.push_back(filename_str);
}
if (found_filenames.empty()) {
throw std::runtime_error("No files found in zip archive.");
}
found_filenames += filename_str;
found_filenames += " ";
// All the paths are prepended with a tmp/ directory. We need to find the
// prefix.
std::string file_prefix;
size_t pos = found_filenames[0].find('/');
std::string prefix0 = found_filenames[0].substr(0, pos);
pos = found_filenames[1].find('/');
std::string prefix1 = found_filenames[1].substr(0, pos);
if (!prefix0.empty() && !prefix1.empty() && prefix0 == prefix1) {
file_prefix = prefix0 + "/";
} else {
LOG(WARNING)
<< "You are using an outdated version of the pt2 archive which do not have a prefix in front of each filename. Example: \n"
<< found_filenames[0] << "\n"
<< found_filenames[1];
}
temp_dir_ = create_temp_dir();
std::string so_filename;
std::string cpp_filename;
std::vector<std::string> obj_filenames;
std::string model_directory = file_prefix + "data" + k_separator +
"aotinductor" + k_separator + model_name;
std::string const_directory = "data" + k_separator + "constants";
for (const std::string& filename_str : found_filenames) {
// Only compile files in the specified model directory
if (c10::starts_with(filename_str, model_directory) ||
c10::starts_with(filename_str, const_directory)) {
@ -482,13 +460,9 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
}
if (cpp_filename.empty() && so_filename.empty()) {
std::string found_filenames_str;
for (const std::string& filename : found_filenames) {
found_filenames_str += filename + "\n";
}
throw std::runtime_error(
"No AOTInductor generate cpp file or so file found in zip archive. Loaded the following:\n" +
found_filenames_str);
found_filenames);
}
// Compile the .so

View File

@ -1,4 +0,0 @@
from ._package import is_pt2_package, PT2ArchiveReader, PT2ArchiveWriter
__all__ = ["PT2ArchiveWriter", "PT2ArchiveReader", "is_pt2_package"]

View File

@ -1,180 +0,0 @@
# pyre-unsafe
import glob
import io
import logging
import os
import zipfile
from typing import Any, Union
import torch
from torch.export.pt2_archive.constants import (
ARCHIVE_FORMAT_PATH,
ARCHIVE_FORMAT_VALUE,
ARCHIVE_VERSION_PATH,
ARCHIVE_VERSION_VALUE,
)
from torch.types import FileLike
logger: logging.Logger = logging.getLogger(__name__)
def is_pt2_package(serialized_model: Union[bytes, str]) -> bool:
    """
    Check if the serialized model is a PT2 Archive package.

    serialized_model: Either the raw bytes of the archive or a path that
        ``zipfile.ZipFile`` can open.

    Returns True only when the archive contains
    ``<root_folder>/<ARCHIVE_FORMAT_PATH>`` and that member's content is
    ``b"pt2"``. Any failure while inspecting the archive is logged at INFO
    level and reported as False.
    """
    try:
        source = (
            io.BytesIO(serialized_model)
            if isinstance(serialized_model, bytes)
            else serialized_model
        )
        # Use a context manager so the zip handle is always released.
        with zipfile.ZipFile(source) as zip_reader:
            names = zip_reader.namelist()
            # Zip member names always use "/" as the separator, regardless
            # of the host OS, so do not split on os.path.sep here.
            root_folder = names[0].split("/")[0]
            archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}"
            if archive_format_path in names:
                return zip_reader.read(archive_format_path) == b"pt2"
    except Exception as ex:
        logger.info("Model is not a PT2 package: %s", str(ex))
    return False
class PT2ArchiveWriter:
    """
    Context manager for writing a PT2 archive.

    Records are written through ``torch._C.PyTorchFileWriter``. On exit the
    archive format and archive version records are added if they have not
    been written yet, and the archive is finalized.
    """

    def __init__(self, archive_path_or_buffer: FileLike):
        self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer)  # type: ignore[arg-type]
        # NOTICE: version here is different from the archive_version
        # this is the version of zip file format, which is used by PyTorchFileWriter, which write to /.data/version
        # archive_version is the version of the PT2 archive spec, which write to /archive_version
        self.archive_file.set_min_version(6)

    def __enter__(self) -> "PT2ArchiveWriter":
        return self

    def __exit__(self, *args: Any) -> None:
        # Fill in the format/version records the caller did not write, in
        # that order, then finalize the archive.
        default_records = (
            (ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE),
            (ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE),
        )
        for record_name, default_value in default_records:
            if not self.has_record(record_name):
                self.write_string(record_name, default_value)
        self.close()

    def has_record(self, name: str) -> bool:
        """
        Tell whether a record called ``name`` has already been written.
        """
        written_records = self.archive_file.get_all_written_records()
        return name in written_records

    def count_prefix(self, prefix: str) -> int:
        """
        Return how many written records have names starting with ``prefix``.
        """
        matching = [
            record
            for record in self.archive_file.get_all_written_records()
            if record.startswith(prefix)
        ]
        return len(matching)

    def write_bytes(self, name: str, data: bytes) -> None:
        """
        Store a bytes object as a record of the archive.
        name: The destination file inside the archive.
        data: The bytes object to write.
        """
        assert isinstance(data, bytes), f"Expected bytes but got {type(data)}"
        size = len(data)
        self.archive_file.write_record(name, data, size)

    def write_string(self, name: str, data: str) -> None:
        """
        Store a string as a record of the archive, encoded with ``str.encode()``.
        name: The destination file inside the archive.
        data: The string object to write.
        """
        assert isinstance(data, str), f"Expected string but got {type(data)}"
        self.write_bytes(name, data.encode())

    def write_file(self, name: str, file_path: str) -> None:
        """
        Store the content of an on-disk file as a record of the archive.
        name: The destination file inside the archive.
        file_path: The source file on disk.
        """
        assert os.path.isfile(file_path), f"{file_path} is not a valid file path"
        with open(file_path, "rb") as source:
            self.write_bytes(name, source.read())

    def write_folder(self, archive_dir: str, folder_dir: str) -> None:
        """
        Store every regular file found under a folder (recursively).
        archive_dir: The destination folder inside the archive.
        folder_dir: The source folder on disk.
        """
        assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path"
        for candidate in glob.glob(f"{folder_dir}/**", recursive=True):
            # The glob also yields directories; only files become records.
            if not os.path.isfile(candidate):
                continue
            relative_name = os.path.relpath(candidate, folder_dir)
            self.write_file(os.path.join(archive_dir, relative_name), candidate)

    def close(self) -> None:
        """
        Finalize the archive by writing its end-of-file marker.
        """
        self.archive_file.write_end_of_file()
class PT2ArchiveReader:
    """
    Context manager for reading a PT2 archive.

    Records are read through ``torch._C.PyTorchFileReader``; construction
    asserts that the archive format record holds the expected value.
    """

    def __init__(self, archive_path_or_buffer: FileLike):
        self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer)  # type: ignore[arg-type]
        assert (
            ARCHIVE_FORMAT_VALUE == self.read_string(ARCHIVE_FORMAT_PATH)
        ), "Invalid archive format"

    def __enter__(self) -> "PT2ArchiveReader":
        return self

    def __exit__(self, *args: Any) -> None:
        # Nothing to release: torch._C.PyTorchFileReader doesn't have a close method
        return None

    def read_bytes(self, name: str) -> bytes:
        """
        Fetch a record of the archive as raw bytes.
        name: The source file inside the archive.
        """
        record = self.archive_file.get_record(name)
        return record

    def read_string(self, name: str) -> str:
        """
        Fetch a record of the archive and decode it with ``bytes.decode()``.
        name: The source file inside the archive.
        """
        return self.read_bytes(name).decode()

    def archive_version(self) -> int:
        """
        Return the archive version as an integer.
        """
        try:
            raw_version = self.read_string(ARCHIVE_VERSION_PATH)
        except Exception:
            # A missing archive_version record means the archive predates
            # version 0 of the spec; such archives are treated as version 0.
            raw_version = "0"
        return int(raw_version)