Revert "[export] Move PT2ArchiveWriter/Reader to torch/export (#153795)"

This reverts commit 7e80f23516.

Reverted https://github.com/pytorch/pytorch/pull/153795 on behalf of https://github.com/malfet due to Looks like it broke lots of tests, see ec368a1903/1 ([comment](https://github.com/pytorch/pytorch/pull/153795#issuecomment-2905415496))
PyTorch MergeBot 2025-05-23 18:29:08 +00:00
parent ec368a1903
commit 4ff19ecf66
6 changed files with 85 additions and 231 deletions

View File

@@ -2650,11 +2650,6 @@
   "torch.export.graph_signature": [
     "TokenArgument"
   ],
-  "torch.export.pt2_archive": [
-    "PT2ArchiveWriter",
-    "PT2ArchiveReader",
-    "is_pt2_package"
-  ],
   "torch.fx.experimental.shape_inference.infer_shape": [
     "DimDynamic",
     "FakeTensorMode",

View File

@@ -216,10 +216,8 @@ class TestAOTInductorPackage(TestCase):
         with tempfile.TemporaryDirectory() as tmp_dir, zipfile.ZipFile(
             package_path, "r"
         ) as zip_ref:
-            filenames = zip_ref.namelist()
-            prefix = filenames[0].split("/")[0]
             zip_ref.extractall(tmp_dir)
-            tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model"
+            tmp_path = Path(tmp_dir) / "data" / "aotinductor" / "model"
             self.assertTrue(tmp_path.exists())
             if self.device == GPU_TYPE:
                 kernel_bin = get_kernel_bin_format(self.device)

View File

@@ -3,7 +3,10 @@ import json
 import logging
 import os
 import tempfile
+import zipfile
+from pathlib import Path
 from typing import Any, IO, Optional, Union
+from typing_extensions import Self
 
 import torch
 import torch._inductor
@@ -11,9 +14,9 @@ import torch.utils._pytree as pytree
 from torch._inductor import config
 from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
 from torch.export._tree_utils import reorder_kwargs
-from torch.export.pt2_archive._package import PT2ArchiveWriter
 from torch.export.pt2_archive.constants import (
     AOTINDUCTOR_DIR,
+    ARCHIVE_VERSION_VALUE,
     CONSTANTS_DIR,
     CUSTOM_OBJ_FILENAME_PREFIX,
 )
@@ -23,6 +26,74 @@ from torch.types import FileLike
 log = logging.getLogger(__name__)
 
 
+class PT2ArchiveWriter:
+    def __init__(self, archive_path: FileLike) -> None:
+        self.archive_path: FileLike = archive_path
+        self.archive_file: Optional[zipfile.ZipFile] = None
+
+    def __enter__(self) -> Self:
+        assert self.archive_file is None
+        self.archive_file = zipfile.ZipFile(
+            self.archive_path, "w", compression=zipfile.ZIP_STORED
+        )
+        self.writestr("version", str(ARCHIVE_VERSION_VALUE))
+        self.writestr("archive_format", "pt2")
+        return self
+
+    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
+        assert self.archive_file is not None
+        self.archive_file.close()
+        self.archive_file = None
+        return None
+
+    def writestr(self, name: str, data: Union[bytes, str]) -> None:
+        assert self.archive_file is not None
+        self.archive_file.writestr(name, data)
+
+    def write_file(self, name: str, file_path: str) -> None:
+        """
+        Copy a file into the archive.
+        name: The destination file inside the archive.
+        file_path: The source file on disk.
+        """
+        assert Path(file_path).is_file(), f"{file_path} is not a valid file path"
+        assert self.archive_file is not None
+        self.archive_file.write(file_path, arcname=name)
+
+
+class PT2ArchiveReader:
+    def __init__(self, archive_path: str) -> None:
+        self.archive_path: str = archive_path
+        self.archive_file: Optional[zipfile.ZipFile] = None
+
+    def __enter__(self) -> Self:
+        self.archive_file = zipfile.ZipFile(
+            self.archive_path, "r", compression=zipfile.ZIP_STORED
+        )
+        return self
+
+    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
+        if self.archive_file is not None:
+            self.archive_file.close()
+        return None
+
+    def read(self, name: str) -> bytes:
+        assert self.archive_file is not None
+        return self.archive_file.read(name)
+
+    def extract_to_path(self, member: str, path: str) -> str:
+        assert self.archive_file is not None
+        return self.archive_file.extract(member, path)
+
+    def extractall(self, path: str) -> None:
+        assert self.archive_file is not None
+        self.archive_file.extractall(path)
+
+    def get_file_names(self) -> list[str]:
+        assert self.archive_file is not None
+        return self.archive_file.namelist()
+
+
 def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
     def get_aoti_file_with_suffix(suffix: str) -> str:
         for file in aoti_files:
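
Not part of the diff: a minimal usage sketch of the zipfile-based PT2ArchiveWriter / PT2ArchiveReader that this revert adds back above. The import path is deliberately omitted (the target file name is not shown in this view), and "model.pt2" / "model.so" are placeholder paths.

    # Assumes PT2ArchiveWriter / PT2ArchiveReader from the hunk above are in scope,
    # and that "model.so" is an existing file on disk (write_file asserts this).
    with PT2ArchiveWriter("model.pt2") as writer:
        # __enter__ already stamped the "version" and "archive_format" records
        writer.writestr("data/notes.txt", "free-form metadata")
        writer.write_file("data/aotinductor/model/model.so", "model.so")

    with PT2ArchiveReader("model.pt2") as reader:
        print(reader.get_file_names())        # all members, including "archive_format"
        blob = reader.read("data/notes.txt")  # returned as bytes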

View File

@@ -367,7 +367,15 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
         mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive)));
   }
 
-  std::vector<std::string> found_filenames;
+  temp_dir_ = create_temp_dir();
+  std::string so_filename;
+  std::string cpp_filename;
+  std::vector<std::string> obj_filenames;
+  std::string found_filenames; // Saving for bookkeeping
+  std::string model_directory =
+      "data" + k_separator + "aotinductor" + k_separator + model_name;
+  std::string const_directory = "data" + k_separator + "constants";
+
   for (uint32_t i = 0; i < zip_archive.m_total_files; i++) {
     uint32_t filename_len =
         mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0);
@@ -381,40 +389,10 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
             &zip_archive, i, filename_str.data(), filename_len)) {
       throw std::runtime_error("Failed to read filename");
     }
-    found_filenames.push_back(filename_str);
-  }
-
-  if (found_filenames.empty()) {
-    throw std::runtime_error("No files found in zip archive.");
-  }
-
-  // All the paths are prepended with a tmp/ directory. We need to find the
-  // prefix.
-  std::string file_prefix;
-  size_t pos = found_filenames[0].find('/');
-  std::string prefix0 = found_filenames[0].substr(0, pos);
-  pos = found_filenames[1].find('/');
-  std::string prefix1 = found_filenames[1].substr(0, pos);
-  if (!prefix0.empty() && !prefix1.empty() && prefix0 == prefix1) {
-    file_prefix = prefix0 + "/";
-  } else {
-    LOG(WARNING)
-        << "You are using an outdated version of the pt2 archive which do not have a prefix in front of each filename. Example: \n"
-        << found_filenames[0] << "\n"
-        << found_filenames[1];
-  }
-
-  temp_dir_ = create_temp_dir();
-  std::string so_filename;
-  std::string cpp_filename;
-  std::vector<std::string> obj_filenames;
-  std::string model_directory = file_prefix + "data" + k_separator +
-      "aotinductor" + k_separator + model_name;
-  std::string const_directory = "data" + k_separator + "constants";
-
-  for (const std::string& filename_str : found_filenames) {
+
+    found_filenames += filename_str;
+    found_filenames += " ";
+
     // Only compile files in the specified model directory
     if (c10::starts_with(filename_str, model_directory) ||
         c10::starts_with(filename_str, const_directory)) {
@@ -482,13 +460,9 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
   }
 
   if (cpp_filename.empty() && so_filename.empty()) {
-    std::string found_filenames_str;
-    for (const std::string& filename : found_filenames) {
-      found_filenames_str += filename + "\n";
-    }
     throw std::runtime_error(
         "No AOTInductor generate cpp file or so file found in zip archive. Loaded the following:\n" +
-        found_filenames_str);
+        found_filenames);
   }
 
   // Compile the .so

View File

@ -1,4 +0,0 @@
from ._package import is_pt2_package, PT2ArchiveReader, PT2ArchiveWriter
__all__ = ["PT2ArchiveWriter", "PT2ArchiveReader", "is_pt2_package"]

View File

@@ -1,180 +0,0 @@
-# pyre-unsafe
-
-import glob
-import io
-import logging
-import os
-import zipfile
-from typing import Any, Union
-
-import torch
-from torch.export.pt2_archive.constants import (
-    ARCHIVE_FORMAT_PATH,
-    ARCHIVE_FORMAT_VALUE,
-    ARCHIVE_VERSION_PATH,
-    ARCHIVE_VERSION_VALUE,
-)
-from torch.types import FileLike
-
-
-logger: logging.Logger = logging.getLogger(__name__)
-
-
-def is_pt2_package(serialized_model: Union[bytes, str]) -> bool:
-    """
-    Check if the serialized model is a PT2 Archive package.
-    """
-    try:
-        zip_reader = zipfile.ZipFile(
-            io.BytesIO(serialized_model)
-            if isinstance(serialized_model, bytes)
-            else serialized_model
-        )
-        root_folder = zip_reader.namelist()[0].split(os.path.sep)[0]
-        archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}"
-        if archive_format_path in zip_reader.namelist():
-            return zip_reader.read(archive_format_path) == b"pt2"
-    except Exception as ex:
-        logger.info("Model is not a PT2 package: %s", str(ex))
-    return False
-
-
-class PT2ArchiveWriter:
-    """
-    Context manager for writing a PT2 archive.
-    """
-
-    def __init__(self, archive_path_or_buffer: FileLike):
-        self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer)  # type: ignore[arg-type]
-        # NOTICE: version here is different from the archive_version
-        # this is the version of zip file format, which is used by PyTorchFileWriter, which write to /.data/version
-        # archive_version is the version of the PT2 archive spec, which write to /archive_version
-        self.archive_file.set_min_version(6)
-
-    def __enter__(self) -> "PT2ArchiveWriter":
-        return self
-
-    def __exit__(self, *args: Any) -> None:
-        if not self.has_record(ARCHIVE_FORMAT_PATH):
-            self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE)
-
-        if not self.has_record(ARCHIVE_VERSION_PATH):
-            self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE)
-
-        self.close()
-
-    def has_record(self, name: str) -> bool:
-        """
-        Check if a record exists in the archive.
-        """
-        return name in self.archive_file.get_all_written_records()
-
-    def count_prefix(self, prefix: str) -> int:
-        """
-        Count the number of records that start with a given prefix.
-        """
-        return sum(
-            1
-            for record in self.archive_file.get_all_written_records()
-            if record.startswith(prefix)
-        )
-
-    def write_bytes(self, name: str, data: bytes) -> None:
-        """
-        Write a bytes object to the archive.
-        name: The destination file inside the archive.
-        data: The bytes object to write.
-        """
-        assert isinstance(data, bytes), f"Expected bytes but got {type(data)}"
-        self.archive_file.write_record(name, data, len(data))
-
-    def write_string(self, name: str, data: str) -> None:
-        """
-        Write a string object to the archive.
-        name: The destination file inside the archive.
-        data: The string object to write.
-        """
-        assert isinstance(data, str), f"Expected string but got {type(data)}"
-        data_bytes = data.encode()
-        self.write_bytes(name, data_bytes)
-
-    def write_file(self, name: str, file_path: str) -> None:
-        """
-        Copy a file into the archive.
-        name: The destination file inside the archive.
-        file_path: The source file on disk.
-        """
-        assert os.path.isfile(file_path), f"{file_path} is not a valid file path"
-
-        with open(file_path, "rb") as f:
-            file_bytes = f.read()
-            self.write_bytes(name, file_bytes)
-
-    def write_folder(self, archive_dir: str, folder_dir: str) -> None:
-        """
-        Copy a folder into the archive.
-        archive_dir: The destination folder inside the archive.
-        folder_dir: The source folder on disk.
-        """
-        assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path"
-
-        file_paths = filter(
-            os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
-        )
-        for file_path in file_paths:
-            filename = os.path.relpath(file_path, folder_dir)
-            archive_path = os.path.join(archive_dir, filename)
-            self.write_file(archive_path, file_path)
-
-    def close(self) -> None:
-        """
-        Close the archive.
-        """
-        self.archive_file.write_end_of_file()
-
-
-class PT2ArchiveReader:
-    """
-    Context manager for reading a PT2 archive.
-    """
-
-    def __init__(self, archive_path_or_buffer: FileLike):
-        self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer)  # type: ignore[arg-type]
-        assert (
-            self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE
-        ), "Invalid archive format"
-
-    def __enter__(self) -> "PT2ArchiveReader":
-        return self
-
-    def __exit__(self, *args: Any) -> None:
-        # torch._C.PyTorchFileReader doesn't have a close method
-        pass
-
-    def read_bytes(self, name: str) -> bytes:
-        """
-        Read a bytes object from the archive.
-        name: The source file inside the archive.
-        """
-        return self.archive_file.get_record(name)
-
-    def read_string(self, name: str) -> str:
-        """
-        Read a string object from the archive.
-        name: The source file inside the archive.
-        """
-        data = self.read_bytes(name)
-        return data.decode()
-
-    def archive_version(self) -> int:
-        """
-        Get the archive version.
-        """
-        try:
-            archive_version = self.read_string(ARCHIVE_VERSION_PATH)
-        except Exception:
-            # if archive_version is not found, it means the archive is older than version 0.
-            # In this case, we assume the archive is version 0.
-            archive_version = "0"
-
-        return int(archive_version)
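
Not part of the diff: for reference, a minimal sketch of how the module deleted above was used before this revert, assuming the pre-revert import path torch.export.pt2_archive (removed here together with its __init__.py) and a placeholder path "model.pt2".

    from torch.export.pt2_archive import is_pt2_package, PT2ArchiveReader, PT2ArchiveWriter

    # The writer is backed by torch._C.PyTorchFileWriter; the archive_format and
    # archive_version records are filled in automatically on __exit__ if missing.
    with PT2ArchiveWriter("model.pt2") as writer:
        writer.write_string("data/notes.txt", "free-form metadata")

    print(is_pt2_package("model.pt2"))  # expected to report True for a PT2 archive

    with PT2ArchiveReader("model.pt2") as reader:
        print(reader.archive_version())
        print(reader.read_string("data/notes.txt"))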