mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[export] Move PT2ArchiveWriter/Reader to torch/export (#153795)"
This reverts commit 7e80f23516. Reverted https://github.com/pytorch/pytorch/pull/153795 on behalf of https://github.com/malfet due to: Looks like it broke lots of tests, see ec368a1903/1 ([comment](https://github.com/pytorch/pytorch/pull/153795#issuecomment-2905415496))
This commit is contained in:
parent
ec368a1903
commit
4ff19ecf66
|
|
@ -2650,11 +2650,6 @@
|
|||
"torch.export.graph_signature": [
|
||||
"TokenArgument"
|
||||
],
|
||||
"torch.export.pt2_archive": [
|
||||
"PT2ArchiveWriter",
|
||||
"PT2ArchiveReader",
|
||||
"is_pt2_package"
|
||||
],
|
||||
"torch.fx.experimental.shape_inference.infer_shape": [
|
||||
"DimDynamic",
|
||||
"FakeTensorMode",
|
||||
|
|
|
|||
|
|
@ -216,10 +216,8 @@ class TestAOTInductorPackage(TestCase):
|
|||
with tempfile.TemporaryDirectory() as tmp_dir, zipfile.ZipFile(
|
||||
package_path, "r"
|
||||
) as zip_ref:
|
||||
filenames = zip_ref.namelist()
|
||||
prefix = filenames[0].split("/")[0]
|
||||
zip_ref.extractall(tmp_dir)
|
||||
tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model"
|
||||
tmp_path = Path(tmp_dir) / "data" / "aotinductor" / "model"
|
||||
self.assertTrue(tmp_path.exists())
|
||||
if self.device == GPU_TYPE:
|
||||
kernel_bin = get_kernel_bin_format(self.device)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,10 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Any, IO, Optional, Union
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch
|
||||
import torch._inductor
|
||||
|
|
@ -11,9 +14,9 @@ import torch.utils._pytree as pytree
|
|||
from torch._inductor import config
|
||||
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
|
||||
from torch.export._tree_utils import reorder_kwargs
|
||||
from torch.export.pt2_archive._package import PT2ArchiveWriter
|
||||
from torch.export.pt2_archive.constants import (
|
||||
AOTINDUCTOR_DIR,
|
||||
ARCHIVE_VERSION_VALUE,
|
||||
CONSTANTS_DIR,
|
||||
CUSTOM_OBJ_FILENAME_PREFIX,
|
||||
)
|
||||
|
|
@ -23,6 +26,74 @@ from torch.types import FileLike
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PT2ArchiveWriter:
    """
    Context manager that writes a PT2 archive as an uncompressed zip file.

    Entering the context opens ``archive_path`` for writing and immediately
    records the ``version`` and ``archive_format`` entries. Leaving the
    context closes the underlying zip file and clears the handle.
    """

    def __init__(self, archive_path: FileLike) -> None:
        # The zip handle only exists between __enter__ and __exit__.
        self.archive_file: Optional[zipfile.ZipFile] = None
        self.archive_path: FileLike = archive_path

    def __enter__(self) -> Self:
        # Entering twice without leaving in between is not supported.
        assert self.archive_file is None
        self.archive_file = zipfile.ZipFile(
            self.archive_path, mode="w", compression=zipfile.ZIP_STORED
        )
        # Header entries are written first, in this order.
        for entry_name, entry_value in (
            ("version", str(ARCHIVE_VERSION_VALUE)),
            ("archive_format", "pt2"),
        ):
            self.writestr(entry_name, entry_value)
        return self

    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
        zip_handle = self.archive_file
        assert zip_handle is not None
        zip_handle.close()
        self.archive_file = None

    def writestr(self, name: str, data: Union[bytes, str]) -> None:
        """
        Write in-memory data into the archive.
        name: The destination file inside the archive.
        data: The bytes or string content to store.
        """
        zip_handle = self.archive_file
        assert zip_handle is not None
        zip_handle.writestr(name, data)

    def write_file(self, name: str, file_path: str) -> None:
        """
        Copy a file into the archive.
        name: The destination file inside the archive.
        file_path: The source file on disk.
        """
        source = Path(file_path)
        # The source is validated before the open-archive check.
        assert source.is_file(), f"{file_path} is not a valid file path"
        assert self.archive_file is not None
        self.archive_file.write(file_path, arcname=name)
|
||||
|
||||
|
||||
class PT2ArchiveReader:
    """
    Context manager that reads entries from a PT2 archive stored as a zip file.

    The zip file is opened on ``__enter__`` and closed on ``__exit__``. Every
    accessor asserts that the archive has been opened first.
    """

    def __init__(self, archive_path: str) -> None:
        # No zip handle exists until the context is entered.
        self.archive_file: Optional[zipfile.ZipFile] = None
        self.archive_path: str = archive_path

    def __enter__(self) -> Self:
        self.archive_file = zipfile.ZipFile(
            self.archive_path, mode="r", compression=zipfile.ZIP_STORED
        )
        return self

    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
        zip_handle = self.archive_file
        if zip_handle is None:
            return
        # The closed handle stays on self.archive_file; it is not reset here.
        zip_handle.close()

    def _opened(self) -> zipfile.ZipFile:
        """Return the open zip handle, asserting that one exists."""
        assert self.archive_file is not None
        return self.archive_file

    def read(self, name: str) -> bytes:
        """Return the raw bytes of the archive member ``name``."""
        return self._opened().read(name)

    def extract_to_path(self, member: str, path: str) -> str:
        """Extract ``member`` under directory ``path`` and return the extracted path."""
        return self._opened().extract(member, path)

    def extractall(self, path: str) -> None:
        """Extract every archive member under directory ``path``."""
        self._opened().extractall(path)

    def get_file_names(self) -> list[str]:
        """Return the names of all members in the archive."""
        return self._opened().namelist()
|
||||
|
||||
|
||||
def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
|
||||
def get_aoti_file_with_suffix(suffix: str) -> str:
|
||||
for file in aoti_files:
|
||||
|
|
|
|||
|
|
@ -367,7 +367,15 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
|||
mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive)));
|
||||
}
|
||||
|
||||
std::vector<std::string> found_filenames;
|
||||
temp_dir_ = create_temp_dir();
|
||||
std::string so_filename;
|
||||
std::string cpp_filename;
|
||||
std::vector<std::string> obj_filenames;
|
||||
std::string found_filenames; // Saving for bookkeeping
|
||||
std::string model_directory =
|
||||
"data" + k_separator + "aotinductor" + k_separator + model_name;
|
||||
std::string const_directory = "data" + k_separator + "constants";
|
||||
|
||||
for (uint32_t i = 0; i < zip_archive.m_total_files; i++) {
|
||||
uint32_t filename_len =
|
||||
mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0);
|
||||
|
|
@ -381,40 +389,10 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
|||
&zip_archive, i, filename_str.data(), filename_len)) {
|
||||
throw std::runtime_error("Failed to read filename");
|
||||
}
|
||||
found_filenames.push_back(filename_str);
|
||||
}
|
||||
|
||||
if (found_filenames.empty()) {
|
||||
throw std::runtime_error("No files found in zip archive.");
|
||||
}
|
||||
found_filenames += filename_str;
|
||||
found_filenames += " ";
|
||||
|
||||
// All the paths are prepended with a tmp/ directory. We need to find the
|
||||
// prefix.
|
||||
std::string file_prefix;
|
||||
size_t pos = found_filenames[0].find('/');
|
||||
std::string prefix0 = found_filenames[0].substr(0, pos);
|
||||
pos = found_filenames[1].find('/');
|
||||
std::string prefix1 = found_filenames[1].substr(0, pos);
|
||||
|
||||
if (!prefix0.empty() && !prefix1.empty() && prefix0 == prefix1) {
|
||||
file_prefix = prefix0 + "/";
|
||||
} else {
|
||||
LOG(WARNING)
|
||||
<< "You are using an outdated version of the pt2 archive which do not have a prefix in front of each filename. Example: \n"
|
||||
<< found_filenames[0] << "\n"
|
||||
<< found_filenames[1];
|
||||
}
|
||||
|
||||
temp_dir_ = create_temp_dir();
|
||||
|
||||
std::string so_filename;
|
||||
std::string cpp_filename;
|
||||
std::vector<std::string> obj_filenames;
|
||||
std::string model_directory = file_prefix + "data" + k_separator +
|
||||
"aotinductor" + k_separator + model_name;
|
||||
std::string const_directory = "data" + k_separator + "constants";
|
||||
|
||||
for (const std::string& filename_str : found_filenames) {
|
||||
// Only compile files in the specified model directory
|
||||
if (c10::starts_with(filename_str, model_directory) ||
|
||||
c10::starts_with(filename_str, const_directory)) {
|
||||
|
|
@ -482,13 +460,9 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
|||
}
|
||||
|
||||
if (cpp_filename.empty() && so_filename.empty()) {
|
||||
std::string found_filenames_str;
|
||||
for (const std::string& filename : found_filenames) {
|
||||
found_filenames_str += filename + "\n";
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"No AOTInductor generate cpp file or so file found in zip archive. Loaded the following:\n" +
|
||||
found_filenames_str);
|
||||
found_filenames);
|
||||
}
|
||||
|
||||
// Compile the .so
|
||||
|
|
|
|||
|
|
@ -1,4 +0,0 @@
|
|||
from ._package import is_pt2_package, PT2ArchiveReader, PT2ArchiveWriter
|
||||
|
||||
|
||||
__all__ = ["PT2ArchiveWriter", "PT2ArchiveReader", "is_pt2_package"]
|
||||
|
|
@ -1,180 +0,0 @@
|
|||
# pyre-unsafe
|
||||
|
||||
import glob
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import zipfile
|
||||
from typing import Any, Union
|
||||
|
||||
import torch
|
||||
from torch.export.pt2_archive.constants import (
|
||||
ARCHIVE_FORMAT_PATH,
|
||||
ARCHIVE_FORMAT_VALUE,
|
||||
ARCHIVE_VERSION_PATH,
|
||||
ARCHIVE_VERSION_VALUE,
|
||||
)
|
||||
from torch.types import FileLike
|
||||
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_pt2_package(serialized_model: Union[bytes, str]) -> bool:
    """
    Check if the serialized model is a PT2 Archive package.

    Args:
        serialized_model: Either the raw bytes of the archive, or a value
            that ``zipfile.ZipFile`` can open directly (e.g. a file path).

    Returns:
        True when the archive contains ``<root>/ARCHIVE_FORMAT_PATH`` and
        that record holds ``b"pt2"``. False otherwise, including when the
        input cannot be opened or read as a zip archive.
    """
    try:
        source = (
            io.BytesIO(serialized_model)
            if isinstance(serialized_model, bytes)
            else serialized_model
        )
        # The context manager closes the zip handle on every exit path.
        with zipfile.ZipFile(source) as zip_reader:
            member_names = zip_reader.namelist()
            # Zip member names always use "/" as the separator, regardless of
            # the host OS, so the root folder must not be split on os.path.sep.
            root_folder = member_names[0].split("/")[0]
            archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}"
            if archive_format_path in member_names:
                return zip_reader.read(archive_format_path) == b"pt2"
    except Exception as ex:
        # Best-effort probe: any failure means "not a PT2 package".
        logger.info("Model is not a PT2 package: %s", str(ex))
    return False
|
||||
|
||||
|
||||
class PT2ArchiveWriter:
    """
    Context manager for writing a PT2 archive.

    Records are written through ``torch._C.PyTorchFileWriter``. On leaving the
    context, the archive format and archive version records are added if they
    have not been written yet, and the archive is finalized.
    """

    def __init__(self, archive_path_or_buffer: FileLike):
        self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer)  # type: ignore[arg-type]
        # NOTICE: version here is different from the archive_version
        # this is the version of zip file format, which is used by PyTorchFileWriter, which write to /.data/version
        # archive_version is the version of the PT2 archive spec, which write to /archive_version
        self.archive_file.set_min_version(6)

    def __enter__(self) -> "PT2ArchiveWriter":
        return self

    def __exit__(self, *args: Any) -> None:
        # Fill in the mandatory metadata records only if the caller has not
        # already written them.
        if not self.has_record(ARCHIVE_FORMAT_PATH):
            self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE)

        if not self.has_record(ARCHIVE_VERSION_PATH):
            self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE)

        self.close()

    def has_record(self, name: str) -> bool:
        """
        Check if a record exists in the archive.
        """
        return name in self.archive_file.get_all_written_records()

    def count_prefix(self, prefix: str) -> int:
        """
        Count the number of records that start with a given prefix.
        """
        return sum(
            1
            for record in self.archive_file.get_all_written_records()
            if record.startswith(prefix)
        )

    def write_bytes(self, name: str, data: bytes) -> None:
        """
        Write a bytes object to the archive.
        name: The destination file inside the archive.
        data: The bytes object to write.
        """
        assert isinstance(data, bytes), f"Expected bytes but got {type(data)}"
        self.archive_file.write_record(name, data, len(data))

    def write_string(self, name: str, data: str) -> None:
        """
        Write a string object to the archive.
        name: The destination file inside the archive.
        data: The string object to write.
        """
        assert isinstance(data, str), f"Expected string but got {type(data)}"
        data_bytes = data.encode()
        self.write_bytes(name, data_bytes)

    def write_file(self, name: str, file_path: str) -> None:
        """
        Copy a file into the archive.
        name: The destination file inside the archive.
        file_path: The source file on disk.
        """
        assert os.path.isfile(file_path), f"{file_path} is not a valid file path"

        with open(file_path, "rb") as f:
            file_bytes = f.read()
            self.write_bytes(name, file_bytes)

    def write_folder(self, archive_dir: str, folder_dir: str) -> None:
        """
        Copy a folder into the archive.
        archive_dir: The destination folder inside the archive.
        folder_dir: The source folder on disk.
        """
        assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path"

        file_paths = filter(
            os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
        )
        for file_path in file_paths:
            filename = os.path.relpath(file_path, folder_dir)
            # os.path.relpath/join use the host separator; record names inside
            # the archive must use "/" so they match on every platform.
            # This is a no-op where os.sep is already "/".
            archive_path = os.path.join(archive_dir, filename).replace(os.sep, "/")
            self.write_file(archive_path, file_path)

    def close(self) -> None:
        """
        Close the archive.
        """
        self.archive_file.write_end_of_file()
|
||||
|
||||
|
||||
class PT2ArchiveReader:
    """
    Context manager for reading a PT2 archive.

    Construction opens the archive through ``torch._C.PyTorchFileReader`` and
    asserts that its archive-format record equals ``ARCHIVE_FORMAT_VALUE``.
    """

    def __init__(self, archive_path_or_buffer: FileLike):
        self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer)  # type: ignore[arg-type]
        stored_format = self.read_string(ARCHIVE_FORMAT_PATH)
        assert stored_format == ARCHIVE_FORMAT_VALUE, "Invalid archive format"

    def __enter__(self) -> "PT2ArchiveReader":
        return self

    def __exit__(self, *args: Any) -> None:
        # torch._C.PyTorchFileReader doesn't have a close method
        return None

    def read_bytes(self, name: str) -> bytes:
        """
        Read a bytes object from the archive.
        name: The source file inside the archive.
        """
        record = self.archive_file.get_record(name)
        return record

    def read_string(self, name: str) -> str:
        """
        Read a string object from the archive.
        name: The source file inside the archive.
        """
        return self.read_bytes(name).decode()

    def archive_version(self) -> int:
        """
        Get the archive version.
        """
        try:
            raw_version = self.read_string(ARCHIVE_VERSION_PATH)
        except Exception:
            # if archive_version is not found, it means the archive is older than version 0.
            # In this case, we assume the archive is version 0.
            return 0
        return int(raw_version)
|
||||
Loading…
Reference in New Issue
Block a user