[export] Move PT2ArchiveWriter/Reader to torch/export (#153795)
Summary:

Before: `from sigmoid.core.package.pt2_archive import PT2ArchiveWriter, PT2ArchiveReader, is_sigmoid_package`
After: `from torch.export.pt2_archive import PT2ArchiveWriter, PT2ArchiveReader, is_pt2_package`

By merging the two PT2ArchiveReader/Writer implementations into ones backed by the native PyTorchFileReader/Writer, the open-source PT2 archive layout gained an additional top-level folder. This PR still maintains support for loading an old PT2 archive which does not have the additional folder.

Before:
```
├── archive_format
├── byteorder
├── .data
│   ├── serialization_id
│   └── version
├── data
│   ├── aotinductor
```

After:
```
├── tmp
│   ├── archive_format
│   ├── byteorder
│   ├── .data
│   │   ├── serialization_id
│   │   └── version
│   ├── data
│   │   ├── aotinductor
```

Test Plan: `buck2 test //sigmoid/...`
https://www.internalfb.com/intern/testinfra/testrun/5348024839248187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153795
Approved by: https://github.com/zhxchen17
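For orientation, here is a minimal sketch of how a consumer can locate model files under the new layout when unpacking an archive by hand, mirroring the updated test further down; the archive path and extraction directory below are illustrative, not from this PR:

```
import zipfile
from pathlib import Path

package_path = "model.pt2"  # illustrative; any archive produced after this change
extract_dir = "unpacked"    # illustrative destination

with zipfile.ZipFile(package_path, "r") as zip_ref:
    filenames = zip_ref.namelist()
    # New-style archives nest everything under a single top-level folder
    # (e.g. "tmp"), so recover it from the first entry before building paths.
    prefix = filenames[0].split("/")[0]
    zip_ref.extractall(extract_dir)
    model_dir = Path(extract_dir) / prefix / "data" / "aotinductor"
```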
This commit is contained in:
parent 499a76b844
commit 3b21d79225
```
@@ -2650,6 +2650,11 @@
   "torch.export.graph_signature": [
     "TokenArgument"
   ],
+  "torch.export.pt2_archive": [
+    "PT2ArchiveWriter",
+    "PT2ArchiveReader",
+    "is_pt2_package"
+  ],
   "torch.fx.experimental.shape_inference.infer_shape": [
     "DimDynamic",
     "FakeTensorMode",
```
```
@@ -216,8 +216,10 @@ class TestAOTInductorPackage(TestCase):
         with tempfile.TemporaryDirectory() as tmp_dir, zipfile.ZipFile(
             package_path, "r"
         ) as zip_ref:
+            filenames = zip_ref.namelist()
+            prefix = filenames[0].split("/")[0]
             zip_ref.extractall(tmp_dir)
-            tmp_path = Path(tmp_dir) / "data" / "aotinductor" / "model"
+            tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model"
             self.assertTrue(tmp_path.exists())
             if self.device == GPU_TYPE:
                 kernel_bin = get_kernel_bin_format(self.device)
```
```
@@ -3,10 +3,7 @@ import json
 import logging
 import os
 import tempfile
-import zipfile
-from pathlib import Path
 from typing import Any, IO, Optional, Union
-from typing_extensions import Self
 
 import torch
 import torch._inductor
```
```
@@ -14,9 +11,9 @@ import torch.utils._pytree as pytree
 from torch._inductor import config
 from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
 from torch.export._tree_utils import reorder_kwargs
+from torch.export.pt2_archive._package import PT2ArchiveWriter
 from torch.export.pt2_archive.constants import (
     AOTINDUCTOR_DIR,
-    ARCHIVE_VERSION_VALUE,
     CONSTANTS_DIR,
     CUSTOM_OBJ_FILENAME_PREFIX,
 )
```
```
@@ -26,74 +23,6 @@ from torch.types import FileLike
 log = logging.getLogger(__name__)
 
 
-class PT2ArchiveWriter:
-    def __init__(self, archive_path: FileLike) -> None:
-        self.archive_path: FileLike = archive_path
-        self.archive_file: Optional[zipfile.ZipFile] = None
-
-    def __enter__(self) -> Self:
-        assert self.archive_file is None
-        self.archive_file = zipfile.ZipFile(
-            self.archive_path, "w", compression=zipfile.ZIP_STORED
-        )
-        self.writestr("version", str(ARCHIVE_VERSION_VALUE))
-        self.writestr("archive_format", "pt2")
-        return self
-
-    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
-        assert self.archive_file is not None
-        self.archive_file.close()
-        self.archive_file = None
-        return None
-
-    def writestr(self, name: str, data: Union[bytes, str]) -> None:
-        assert self.archive_file is not None
-        self.archive_file.writestr(name, data)
-
-    def write_file(self, name: str, file_path: str) -> None:
-        """
-        Copy a file into the archive.
-        name: The destination file inside the archive.
-        file_path: The source file on disk.
-        """
-        assert Path(file_path).is_file(), f"{file_path} is not a valid file path"
-        assert self.archive_file is not None
-        self.archive_file.write(file_path, arcname=name)
-
-
-class PT2ArchiveReader:
-    def __init__(self, archive_path: str) -> None:
-        self.archive_path: str = archive_path
-        self.archive_file: Optional[zipfile.ZipFile] = None
-
-    def __enter__(self) -> Self:
-        self.archive_file = zipfile.ZipFile(
-            self.archive_path, "r", compression=zipfile.ZIP_STORED
-        )
-        return self
-
-    def __exit__(self, *args) -> None:  # type: ignore[no-untyped-def]
-        if self.archive_file is not None:
-            self.archive_file.close()
-        return None
-
-    def read(self, name: str) -> bytes:
-        assert self.archive_file is not None
-        return self.archive_file.read(name)
-
-    def extract_to_path(self, member: str, path: str) -> str:
-        assert self.archive_file is not None
-        return self.archive_file.extract(member, path)
-
-    def extractall(self, path: str) -> None:
-        assert self.archive_file is not None
-        self.archive_file.extractall(path)
-
-    def get_file_names(self) -> list[str]:
-        assert self.archive_file is not None
-        return self.archive_file.namelist()
-
-
 def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
     def get_aoti_file_with_suffix(suffix: str) -> str:
         for file in aoti_files:
```
```
@@ -367,15 +367,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
         mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive)));
   }
 
-  temp_dir_ = create_temp_dir();
-  std::string so_filename;
-  std::string cpp_filename;
-  std::vector<std::string> obj_filenames;
-  std::string found_filenames; // Saving for bookkeeping
-  std::string model_directory =
-      "data" + k_separator + "aotinductor" + k_separator + model_name;
-  std::string const_directory = "data" + k_separator + "constants";
-
+  std::vector<std::string> found_filenames;
   for (uint32_t i = 0; i < zip_archive.m_total_files; i++) {
     uint32_t filename_len =
         mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0);
```
```
@@ -389,10 +381,40 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
             &zip_archive, i, filename_str.data(), filename_len)) {
       throw std::runtime_error("Failed to read filename");
     }
+    found_filenames.push_back(filename_str);
+  }
 
-    found_filenames += filename_str;
-    found_filenames += " ";
+  if (found_filenames.empty()) {
+    throw std::runtime_error("No files found in zip archive.");
+  }
+
+  // All the paths are prepended with a tmp/ directory. We need to find the
+  // prefix.
+  std::string file_prefix;
+  size_t pos = found_filenames[0].find('/');
+  std::string prefix0 = found_filenames[0].substr(0, pos);
+  pos = found_filenames[1].find('/');
+  std::string prefix1 = found_filenames[1].substr(0, pos);
+
+  if (!prefix0.empty() && !prefix1.empty() && prefix0 == prefix1) {
+    file_prefix = prefix0 + "/";
+  } else {
+    LOG(WARNING)
+        << "You are using an outdated version of the pt2 archive which do not have a prefix in front of each filename. Example: \n"
+        << found_filenames[0] << "\n"
+        << found_filenames[1];
+  }
+
+  temp_dir_ = create_temp_dir();
+
+  std::string so_filename;
+  std::string cpp_filename;
+  std::vector<std::string> obj_filenames;
+  std::string model_directory = file_prefix + "data" + k_separator +
+      "aotinductor" + k_separator + model_name;
+  std::string const_directory = "data" + k_separator + "constants";
 
+  for (const std::string& filename_str : found_filenames) {
     // Only compile files in the specified model directory
     if (c10::starts_with(filename_str, model_directory) ||
         c10::starts_with(filename_str, const_directory)) {
```
```
@@ -460,9 +482,13 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
   }
 
   if (cpp_filename.empty() && so_filename.empty()) {
+    std::string found_filenames_str;
+    for (const std::string& filename : found_filenames) {
+      found_filenames_str += filename + "\n";
+    }
     throw std::runtime_error(
         "No AOTInductor generate cpp file or so file found in zip archive. Loaded the following:\n" +
-        found_filenames);
+        found_filenames_str);
   }
 
   // Compile the .so
```
```
@@ -0,0 +1,4 @@
+from ._package import is_pt2_package, PT2ArchiveReader, PT2ArchiveWriter
+
+
+__all__ = ["PT2ArchiveWriter", "PT2ArchiveReader", "is_pt2_package"]
```
torch/export/pt2_archive/_package.py (new file, 180 lines)

```
# pyre-unsafe

import glob
import io
import logging
import os
import zipfile
from typing import Any, Union

import torch
from torch.export.pt2_archive.constants import (
    ARCHIVE_FORMAT_PATH,
    ARCHIVE_FORMAT_VALUE,
    ARCHIVE_VERSION_PATH,
    ARCHIVE_VERSION_VALUE,
)
from torch.types import FileLike


logger: logging.Logger = logging.getLogger(__name__)


def is_pt2_package(serialized_model: Union[bytes, str]) -> bool:
    """
    Check if the serialized model is a PT2 Archive package.
    """
    try:
        zip_reader = zipfile.ZipFile(
            io.BytesIO(serialized_model)
            if isinstance(serialized_model, bytes)
            else serialized_model
        )
        root_folder = zip_reader.namelist()[0].split(os.path.sep)[0]
        archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}"
        if archive_format_path in zip_reader.namelist():
            return zip_reader.read(archive_format_path) == b"pt2"
    except Exception as ex:
        logger.info("Model is not a PT2 package: %s", str(ex))
    return False


class PT2ArchiveWriter:
    """
    Context manager for writing a PT2 archive.
    """

    def __init__(self, archive_path_or_buffer: FileLike):
        self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer)  # type: ignore[arg-type]
        # NOTICE: version here is different from the archive_version.
        # This is the version of the zip file format used by PyTorchFileWriter,
        # which is written to /.data/version. archive_version is the version of
        # the PT2 archive spec, which is written to /archive_version.
        self.archive_file.set_min_version(6)

    def __enter__(self) -> "PT2ArchiveWriter":
        return self

    def __exit__(self, *args: Any) -> None:
        if not self.has_record(ARCHIVE_FORMAT_PATH):
            self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE)

        if not self.has_record(ARCHIVE_VERSION_PATH):
            self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE)

        self.close()

    def has_record(self, name: str) -> bool:
        """
        Check if a record exists in the archive.
        """
        return name in self.archive_file.get_all_written_records()

    def count_prefix(self, prefix: str) -> int:
        """
        Count the number of records that start with a given prefix.
        """
        return sum(
            1
            for record in self.archive_file.get_all_written_records()
            if record.startswith(prefix)
        )

    def write_bytes(self, name: str, data: bytes) -> None:
        """
        Write a bytes object to the archive.
        name: The destination file inside the archive.
        data: The bytes object to write.
        """
        assert isinstance(data, bytes), f"Expected bytes but got {type(data)}"
        self.archive_file.write_record(name, data, len(data))

    def write_string(self, name: str, data: str) -> None:
        """
        Write a string object to the archive.
        name: The destination file inside the archive.
        data: The string object to write.
        """
        assert isinstance(data, str), f"Expected string but got {type(data)}"
        data_bytes = data.encode()
        self.write_bytes(name, data_bytes)

    def write_file(self, name: str, file_path: str) -> None:
        """
        Copy a file into the archive.
        name: The destination file inside the archive.
        file_path: The source file on disk.
        """
        assert os.path.isfile(file_path), f"{file_path} is not a valid file path"

        with open(file_path, "rb") as f:
            file_bytes = f.read()
            self.write_bytes(name, file_bytes)

    def write_folder(self, archive_dir: str, folder_dir: str) -> None:
        """
        Copy a folder into the archive.
        archive_dir: The destination folder inside the archive.
        folder_dir: The source folder on disk.
        """
        assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path"

        file_paths = filter(
            os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
        )
        for file_path in file_paths:
            filename = os.path.relpath(file_path, folder_dir)
            archive_path = os.path.join(archive_dir, filename)
            self.write_file(archive_path, file_path)

    def close(self) -> None:
        """
        Close the archive.
        """
        self.archive_file.write_end_of_file()


class PT2ArchiveReader:
    """
    Context manager for reading a PT2 archive.
    """

    def __init__(self, archive_path_or_buffer: FileLike):
        self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer)  # type: ignore[arg-type]
        assert (
            self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE
        ), "Invalid archive format"

    def __enter__(self) -> "PT2ArchiveReader":
        return self

    def __exit__(self, *args: Any) -> None:
        # torch._C.PyTorchFileReader doesn't have a close method
        pass

    def read_bytes(self, name: str) -> bytes:
        """
        Read a bytes object from the archive.
        name: The source file inside the archive.
        """
        return self.archive_file.get_record(name)

    def read_string(self, name: str) -> str:
        """
        Read a string object from the archive.
        name: The source file inside the archive.
        """
        data = self.read_bytes(name)
        return data.decode()

    def archive_version(self) -> int:
        """
        Get the archive version.
        """
        try:
            archive_version = self.read_string(ARCHIVE_VERSION_PATH)
        except Exception:
            # If archive_version is not found, the archive predates version
            # numbering; assume version 0.
            archive_version = "0"

        return int(archive_version)
```
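For reference, a minimal round-trip sketch of the module above; the archive path and record name are illustrative, not from this PR:

```
from torch.export.pt2_archive import (
    PT2ArchiveReader,
    PT2ArchiveWriter,
    is_pt2_package,
)

archive_path = "example.pt2"  # illustrative

# On exit, the writer stamps the archive_format and archive_version
# records if they were not written explicitly.
with PT2ArchiveWriter(archive_path) as writer:
    writer.write_string("data/greeting.txt", "hello")

assert is_pt2_package(archive_path)

# The reader validates the archive format in its constructor; record names
# are resolved relative to the archive's top-level folder.
with PT2ArchiveReader(archive_path) as reader:
    assert reader.read_string("data/greeting.txt") == "hello"
    version = reader.archive_version()
```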