mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[export] Move PT2ArchiveWriter/Reader to torch/export (#153795)"
This reverts commit7e80f23516. Reverted https://github.com/pytorch/pytorch/pull/153795 on behalf of https://github.com/malfet due to Looks like it broke lots of tests, seeec368a1903/1([comment](https://github.com/pytorch/pytorch/pull/153795#issuecomment-2905415496))
This commit is contained in:
parent
ec368a1903
commit
4ff19ecf66
|
|
@ -2650,11 +2650,6 @@
|
||||||
"torch.export.graph_signature": [
|
"torch.export.graph_signature": [
|
||||||
"TokenArgument"
|
"TokenArgument"
|
||||||
],
|
],
|
||||||
"torch.export.pt2_archive": [
|
|
||||||
"PT2ArchiveWriter",
|
|
||||||
"PT2ArchiveReader",
|
|
||||||
"is_pt2_package"
|
|
||||||
],
|
|
||||||
"torch.fx.experimental.shape_inference.infer_shape": [
|
"torch.fx.experimental.shape_inference.infer_shape": [
|
||||||
"DimDynamic",
|
"DimDynamic",
|
||||||
"FakeTensorMode",
|
"FakeTensorMode",
|
||||||
|
|
|
||||||
|
|
@ -216,10 +216,8 @@ class TestAOTInductorPackage(TestCase):
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir, zipfile.ZipFile(
|
with tempfile.TemporaryDirectory() as tmp_dir, zipfile.ZipFile(
|
||||||
package_path, "r"
|
package_path, "r"
|
||||||
) as zip_ref:
|
) as zip_ref:
|
||||||
filenames = zip_ref.namelist()
|
|
||||||
prefix = filenames[0].split("/")[0]
|
|
||||||
zip_ref.extractall(tmp_dir)
|
zip_ref.extractall(tmp_dir)
|
||||||
tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model"
|
tmp_path = Path(tmp_dir) / "data" / "aotinductor" / "model"
|
||||||
self.assertTrue(tmp_path.exists())
|
self.assertTrue(tmp_path.exists())
|
||||||
if self.device == GPU_TYPE:
|
if self.device == GPU_TYPE:
|
||||||
kernel_bin = get_kernel_bin_format(self.device)
|
kernel_bin = get_kernel_bin_format(self.device)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,10 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import zipfile
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, IO, Optional, Union
|
from typing import Any, IO, Optional, Union
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._inductor
|
import torch._inductor
|
||||||
|
|
@ -11,9 +14,9 @@ import torch.utils._pytree as pytree
|
||||||
from torch._inductor import config
|
from torch._inductor import config
|
||||||
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
|
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
|
||||||
from torch.export._tree_utils import reorder_kwargs
|
from torch.export._tree_utils import reorder_kwargs
|
||||||
from torch.export.pt2_archive._package import PT2ArchiveWriter
|
|
||||||
from torch.export.pt2_archive.constants import (
|
from torch.export.pt2_archive.constants import (
|
||||||
AOTINDUCTOR_DIR,
|
AOTINDUCTOR_DIR,
|
||||||
|
ARCHIVE_VERSION_VALUE,
|
||||||
CONSTANTS_DIR,
|
CONSTANTS_DIR,
|
||||||
CUSTOM_OBJ_FILENAME_PREFIX,
|
CUSTOM_OBJ_FILENAME_PREFIX,
|
||||||
)
|
)
|
||||||
|
|
@ -23,6 +26,74 @@ from torch.types import FileLike
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PT2ArchiveWriter:
|
||||||
|
def __init__(self, archive_path: FileLike) -> None:
|
||||||
|
self.archive_path: FileLike = archive_path
|
||||||
|
self.archive_file: Optional[zipfile.ZipFile] = None
|
||||||
|
|
||||||
|
def __enter__(self) -> Self:
|
||||||
|
assert self.archive_file is None
|
||||||
|
self.archive_file = zipfile.ZipFile(
|
||||||
|
self.archive_path, "w", compression=zipfile.ZIP_STORED
|
||||||
|
)
|
||||||
|
self.writestr("version", str(ARCHIVE_VERSION_VALUE))
|
||||||
|
self.writestr("archive_format", "pt2")
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args) -> None: # type: ignore[no-untyped-def]
|
||||||
|
assert self.archive_file is not None
|
||||||
|
self.archive_file.close()
|
||||||
|
self.archive_file = None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def writestr(self, name: str, data: Union[bytes, str]) -> None:
|
||||||
|
assert self.archive_file is not None
|
||||||
|
self.archive_file.writestr(name, data)
|
||||||
|
|
||||||
|
def write_file(self, name: str, file_path: str) -> None:
|
||||||
|
"""
|
||||||
|
Copy a file into the archive.
|
||||||
|
name: The destination file inside the archive.
|
||||||
|
file_path: The source file on disk.
|
||||||
|
"""
|
||||||
|
assert Path(file_path).is_file(), f"{file_path} is not a valid file path"
|
||||||
|
assert self.archive_file is not None
|
||||||
|
self.archive_file.write(file_path, arcname=name)
|
||||||
|
|
||||||
|
|
||||||
|
class PT2ArchiveReader:
|
||||||
|
def __init__(self, archive_path: str) -> None:
|
||||||
|
self.archive_path: str = archive_path
|
||||||
|
self.archive_file: Optional[zipfile.ZipFile] = None
|
||||||
|
|
||||||
|
def __enter__(self) -> Self:
|
||||||
|
self.archive_file = zipfile.ZipFile(
|
||||||
|
self.archive_path, "r", compression=zipfile.ZIP_STORED
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args) -> None: # type: ignore[no-untyped-def]
|
||||||
|
if self.archive_file is not None:
|
||||||
|
self.archive_file.close()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def read(self, name: str) -> bytes:
|
||||||
|
assert self.archive_file is not None
|
||||||
|
return self.archive_file.read(name)
|
||||||
|
|
||||||
|
def extract_to_path(self, member: str, path: str) -> str:
|
||||||
|
assert self.archive_file is not None
|
||||||
|
return self.archive_file.extract(member, path)
|
||||||
|
|
||||||
|
def extractall(self, path: str) -> None:
|
||||||
|
assert self.archive_file is not None
|
||||||
|
self.archive_file.extractall(path)
|
||||||
|
|
||||||
|
def get_file_names(self) -> list[str]:
|
||||||
|
assert self.archive_file is not None
|
||||||
|
return self.archive_file.namelist()
|
||||||
|
|
||||||
|
|
||||||
def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
|
def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
|
||||||
def get_aoti_file_with_suffix(suffix: str) -> str:
|
def get_aoti_file_with_suffix(suffix: str) -> str:
|
||||||
for file in aoti_files:
|
for file in aoti_files:
|
||||||
|
|
|
||||||
|
|
@ -367,7 +367,15 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
||||||
mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive)));
|
mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive)));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> found_filenames;
|
temp_dir_ = create_temp_dir();
|
||||||
|
std::string so_filename;
|
||||||
|
std::string cpp_filename;
|
||||||
|
std::vector<std::string> obj_filenames;
|
||||||
|
std::string found_filenames; // Saving for bookkeeping
|
||||||
|
std::string model_directory =
|
||||||
|
"data" + k_separator + "aotinductor" + k_separator + model_name;
|
||||||
|
std::string const_directory = "data" + k_separator + "constants";
|
||||||
|
|
||||||
for (uint32_t i = 0; i < zip_archive.m_total_files; i++) {
|
for (uint32_t i = 0; i < zip_archive.m_total_files; i++) {
|
||||||
uint32_t filename_len =
|
uint32_t filename_len =
|
||||||
mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0);
|
mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0);
|
||||||
|
|
@ -381,40 +389,10 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
||||||
&zip_archive, i, filename_str.data(), filename_len)) {
|
&zip_archive, i, filename_str.data(), filename_len)) {
|
||||||
throw std::runtime_error("Failed to read filename");
|
throw std::runtime_error("Failed to read filename");
|
||||||
}
|
}
|
||||||
found_filenames.push_back(filename_str);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (found_filenames.empty()) {
|
found_filenames += filename_str;
|
||||||
throw std::runtime_error("No files found in zip archive.");
|
found_filenames += " ";
|
||||||
}
|
|
||||||
|
|
||||||
// All the paths are prepended with a tmp/ directory. We need to find the
|
|
||||||
// prefix.
|
|
||||||
std::string file_prefix;
|
|
||||||
size_t pos = found_filenames[0].find('/');
|
|
||||||
std::string prefix0 = found_filenames[0].substr(0, pos);
|
|
||||||
pos = found_filenames[1].find('/');
|
|
||||||
std::string prefix1 = found_filenames[1].substr(0, pos);
|
|
||||||
|
|
||||||
if (!prefix0.empty() && !prefix1.empty() && prefix0 == prefix1) {
|
|
||||||
file_prefix = prefix0 + "/";
|
|
||||||
} else {
|
|
||||||
LOG(WARNING)
|
|
||||||
<< "You are using an outdated version of the pt2 archive which do not have a prefix in front of each filename. Example: \n"
|
|
||||||
<< found_filenames[0] << "\n"
|
|
||||||
<< found_filenames[1];
|
|
||||||
}
|
|
||||||
|
|
||||||
temp_dir_ = create_temp_dir();
|
|
||||||
|
|
||||||
std::string so_filename;
|
|
||||||
std::string cpp_filename;
|
|
||||||
std::vector<std::string> obj_filenames;
|
|
||||||
std::string model_directory = file_prefix + "data" + k_separator +
|
|
||||||
"aotinductor" + k_separator + model_name;
|
|
||||||
std::string const_directory = "data" + k_separator + "constants";
|
|
||||||
|
|
||||||
for (const std::string& filename_str : found_filenames) {
|
|
||||||
// Only compile files in the specified model directory
|
// Only compile files in the specified model directory
|
||||||
if (c10::starts_with(filename_str, model_directory) ||
|
if (c10::starts_with(filename_str, model_directory) ||
|
||||||
c10::starts_with(filename_str, const_directory)) {
|
c10::starts_with(filename_str, const_directory)) {
|
||||||
|
|
@ -482,13 +460,9 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cpp_filename.empty() && so_filename.empty()) {
|
if (cpp_filename.empty() && so_filename.empty()) {
|
||||||
std::string found_filenames_str;
|
|
||||||
for (const std::string& filename : found_filenames) {
|
|
||||||
found_filenames_str += filename + "\n";
|
|
||||||
}
|
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"No AOTInductor generate cpp file or so file found in zip archive. Loaded the following:\n" +
|
"No AOTInductor generate cpp file or so file found in zip archive. Loaded the following:\n" +
|
||||||
found_filenames_str);
|
found_filenames);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compile the .so
|
// Compile the .so
|
||||||
|
|
|
||||||
|
|
@ -1,4 +0,0 @@
|
||||||
from ._package import is_pt2_package, PT2ArchiveReader, PT2ArchiveWriter
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["PT2ArchiveWriter", "PT2ArchiveReader", "is_pt2_package"]
|
|
||||||
|
|
@ -1,180 +0,0 @@
|
||||||
# pyre-unsafe
|
|
||||||
|
|
||||||
import glob
|
|
||||||
import io
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import zipfile
|
|
||||||
from typing import Any, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.export.pt2_archive.constants import (
|
|
||||||
ARCHIVE_FORMAT_PATH,
|
|
||||||
ARCHIVE_FORMAT_VALUE,
|
|
||||||
ARCHIVE_VERSION_PATH,
|
|
||||||
ARCHIVE_VERSION_VALUE,
|
|
||||||
)
|
|
||||||
from torch.types import FileLike
|
|
||||||
|
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def is_pt2_package(serialized_model: Union[bytes, str]) -> bool:
|
|
||||||
"""
|
|
||||||
Check if the serialized model is a PT2 Archive package.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
zip_reader = zipfile.ZipFile(
|
|
||||||
io.BytesIO(serialized_model)
|
|
||||||
if isinstance(serialized_model, bytes)
|
|
||||||
else serialized_model
|
|
||||||
)
|
|
||||||
root_folder = zip_reader.namelist()[0].split(os.path.sep)[0]
|
|
||||||
archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}"
|
|
||||||
if archive_format_path in zip_reader.namelist():
|
|
||||||
return zip_reader.read(archive_format_path) == b"pt2"
|
|
||||||
except Exception as ex:
|
|
||||||
logger.info("Model is not a PT2 package: %s", str(ex))
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class PT2ArchiveWriter:
|
|
||||||
"""
|
|
||||||
Context manager for writing a PT2 archive.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, archive_path_or_buffer: FileLike):
|
|
||||||
self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer) # type: ignore[arg-type]
|
|
||||||
# NOTICE: version here is different from the archive_version
|
|
||||||
# this is the version of zip file format, which is used by PyTorchFileWriter, which write to /.data/version
|
|
||||||
# archive_version is the version of the PT2 archive spec, which write to /archive_version
|
|
||||||
self.archive_file.set_min_version(6)
|
|
||||||
|
|
||||||
def __enter__(self) -> "PT2ArchiveWriter":
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, *args: Any) -> None:
|
|
||||||
if not self.has_record(ARCHIVE_FORMAT_PATH):
|
|
||||||
self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE)
|
|
||||||
|
|
||||||
if not self.has_record(ARCHIVE_VERSION_PATH):
|
|
||||||
self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE)
|
|
||||||
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
def has_record(self, name: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a record exists in the archive.
|
|
||||||
"""
|
|
||||||
return name in self.archive_file.get_all_written_records()
|
|
||||||
|
|
||||||
def count_prefix(self, prefix: str) -> int:
|
|
||||||
"""
|
|
||||||
Count the number of records that start with a given prefix.
|
|
||||||
"""
|
|
||||||
return sum(
|
|
||||||
1
|
|
||||||
for record in self.archive_file.get_all_written_records()
|
|
||||||
if record.startswith(prefix)
|
|
||||||
)
|
|
||||||
|
|
||||||
def write_bytes(self, name: str, data: bytes) -> None:
|
|
||||||
"""
|
|
||||||
Write a bytes object to the archive.
|
|
||||||
name: The destination file inside the archive.
|
|
||||||
data: The bytes object to write.
|
|
||||||
"""
|
|
||||||
assert isinstance(data, bytes), f"Expected bytes but got {type(data)}"
|
|
||||||
self.archive_file.write_record(name, data, len(data))
|
|
||||||
|
|
||||||
def write_string(self, name: str, data: str) -> None:
|
|
||||||
"""
|
|
||||||
Write a string object to the archive.
|
|
||||||
name: The destination file inside the archive.
|
|
||||||
data: The string object to write.
|
|
||||||
"""
|
|
||||||
assert isinstance(data, str), f"Expected string but got {type(data)}"
|
|
||||||
data_bytes = data.encode()
|
|
||||||
self.write_bytes(name, data_bytes)
|
|
||||||
|
|
||||||
def write_file(self, name: str, file_path: str) -> None:
|
|
||||||
"""
|
|
||||||
Copy a file into the archive.
|
|
||||||
name: The destination file inside the archive.
|
|
||||||
file_path: The source file on disk.
|
|
||||||
"""
|
|
||||||
assert os.path.isfile(file_path), f"{file_path} is not a valid file path"
|
|
||||||
|
|
||||||
with open(file_path, "rb") as f:
|
|
||||||
file_bytes = f.read()
|
|
||||||
self.write_bytes(name, file_bytes)
|
|
||||||
|
|
||||||
def write_folder(self, archive_dir: str, folder_dir: str) -> None:
|
|
||||||
"""
|
|
||||||
Copy a folder into the archive.
|
|
||||||
archive_dir: The destination folder inside the archive.
|
|
||||||
folder_dir: The source folder on disk.
|
|
||||||
"""
|
|
||||||
assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path"
|
|
||||||
|
|
||||||
file_paths = filter(
|
|
||||||
os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
|
|
||||||
)
|
|
||||||
for file_path in file_paths:
|
|
||||||
filename = os.path.relpath(file_path, folder_dir)
|
|
||||||
archive_path = os.path.join(archive_dir, filename)
|
|
||||||
self.write_file(archive_path, file_path)
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
"""
|
|
||||||
Close the archive.
|
|
||||||
"""
|
|
||||||
self.archive_file.write_end_of_file()
|
|
||||||
|
|
||||||
|
|
||||||
class PT2ArchiveReader:
|
|
||||||
"""
|
|
||||||
Context manager for reading a PT2 archive.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, archive_path_or_buffer: FileLike):
|
|
||||||
self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type]
|
|
||||||
assert (
|
|
||||||
self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE
|
|
||||||
), "Invalid archive format"
|
|
||||||
|
|
||||||
def __enter__(self) -> "PT2ArchiveReader":
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, *args: Any) -> None:
|
|
||||||
# torch._C.PyTorchFileReader doesn't have a close method
|
|
||||||
pass
|
|
||||||
|
|
||||||
def read_bytes(self, name: str) -> bytes:
|
|
||||||
"""
|
|
||||||
Read a bytes object from the archive.
|
|
||||||
name: The source file inside the archive.
|
|
||||||
"""
|
|
||||||
return self.archive_file.get_record(name)
|
|
||||||
|
|
||||||
def read_string(self, name: str) -> str:
|
|
||||||
"""
|
|
||||||
Read a string object from the archive.
|
|
||||||
name: The source file inside the archive.
|
|
||||||
"""
|
|
||||||
data = self.read_bytes(name)
|
|
||||||
return data.decode()
|
|
||||||
|
|
||||||
def archive_version(self) -> int:
|
|
||||||
"""
|
|
||||||
Get the archive version.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
archive_version = self.read_string(ARCHIVE_VERSION_PATH)
|
|
||||||
except Exception:
|
|
||||||
# if archive_version is not found, it means the archive is older than version 0.
|
|
||||||
# In this case, we assume the archive is version 0.
|
|
||||||
archive_version = "0"
|
|
||||||
|
|
||||||
return int(archive_version)
|
|
||||||
Loading…
Reference in New Issue
Block a user