mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE]: Add better handling of pathlib.Path with os calls (#116564)
Builds on #116562 to the rest of the instances of pathlib in the PyTorch. * Uses more generic `os.PathLike` and `os.fspath` calls where appropiate Pull Request resolved: https://github.com/pytorch/pytorch/pull/116564 Approved by: https://github.com/malfet
This commit is contained in:
parent
86cd6655a1
commit
aef06c316b
|
|
@ -3,7 +3,6 @@ import dataclasses
|
|||
import functools
|
||||
import io
|
||||
import json
|
||||
import pathlib
|
||||
import re
|
||||
import sys
|
||||
import os
|
||||
|
|
@ -207,7 +206,7 @@ def export(
|
|||
|
||||
def save(
|
||||
ep: ExportedProgram,
|
||||
f: Union[str, pathlib.Path, io.BytesIO],
|
||||
f: Union[str, os.PathLike, io.BytesIO],
|
||||
*,
|
||||
extra_files: Optional[Dict[str, Any]] = None,
|
||||
opset_version: Optional[Dict[str, int]] = None,
|
||||
|
|
@ -216,8 +215,8 @@ def save(
|
|||
from .serde.schema import SCHEMA_VERSION
|
||||
artifact: SerializedArtifact = serialize(ep, opset_version)
|
||||
|
||||
if isinstance(f, (str, pathlib.Path)):
|
||||
f = str(f)
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
f = os.fspath(f)
|
||||
|
||||
with zipfile.ZipFile(f, 'w') as zipf:
|
||||
# Save every field the SerializedArtifact to a file
|
||||
|
|
@ -236,13 +235,13 @@ def save(
|
|||
|
||||
|
||||
def load(
|
||||
f: Union[str, pathlib.Path, io.BytesIO],
|
||||
f: Union[str, os.PathLike, io.BytesIO],
|
||||
*,
|
||||
extra_files: Optional[Dict[str, Any]] = None,
|
||||
expected_opset_version: Optional[Dict[str, int]] = None,
|
||||
) -> ExportedProgram:
|
||||
if isinstance(f, (str, pathlib.Path)):
|
||||
f = str(f)
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
f = os.fspath(f)
|
||||
|
||||
with zipfile.ZipFile(f, 'r') as zipf:
|
||||
# Check the version
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import copy
|
|||
import dataclasses
|
||||
import inspect
|
||||
import io
|
||||
import pathlib
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
|
|
@ -200,7 +200,7 @@ def export(
|
|||
|
||||
def save(
|
||||
ep: ExportedProgram,
|
||||
f: Union[str, pathlib.Path, io.BytesIO],
|
||||
f: Union[str, os.PathLike, io.BytesIO],
|
||||
*,
|
||||
extra_files: Optional[Dict[str, Any]] = None,
|
||||
opset_version: Optional[Dict[str, int]] = None,
|
||||
|
|
@ -217,7 +217,7 @@ def save(
|
|||
Args:
|
||||
ep (ExportedProgram): The exported program to save.
|
||||
|
||||
f (Union[str, pathlib.Path, io.BytesIO): A file-like object (has to
|
||||
f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to
|
||||
implement write and flush) or a string containing a file name.
|
||||
|
||||
extra_files (Optional[Dict[str, Any]]): Map from filename to contents
|
||||
|
|
@ -256,7 +256,7 @@ def save(
|
|||
|
||||
|
||||
def load(
|
||||
f: Union[str, pathlib.Path, io.BytesIO],
|
||||
f: Union[str, os.PathLike, io.BytesIO],
|
||||
*,
|
||||
extra_files: Optional[Dict[str, Any]] = None,
|
||||
expected_opset_version: Optional[Dict[str, int]] = None,
|
||||
|
|
@ -273,7 +273,7 @@ def load(
|
|||
Args:
|
||||
ep (ExportedProgram): The exported program to save.
|
||||
|
||||
f (Union[str, pathlib.Path, io.BytesIO): A file-like object (has to
|
||||
f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to
|
||||
implement write and flush) or a string containing a file name.
|
||||
|
||||
extra_files (Optional[Dict[str, Any]]): The extra filenames given in
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ This is not intended to be imported directly; please use the exposed
|
|||
functionalities in `torch.jit`.
|
||||
"""
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import torch
|
||||
from torch.jit._recursive import wrap_cpp_module
|
||||
|
|
@ -76,7 +75,7 @@ def save(m, f, _extra_files=None):
|
|||
"""
|
||||
if _extra_files is None:
|
||||
_extra_files = {}
|
||||
if isinstance(f, (str, pathlib.Path)):
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
m.save(f, _extra_files=_extra_files)
|
||||
else:
|
||||
ret = m.save_to_buffer(_extra_files=_extra_files)
|
||||
|
|
@ -155,8 +154,8 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
|
|||
_extra_files = {}
|
||||
|
||||
cu = torch._C.CompilationUnit()
|
||||
if isinstance(f, (str, pathlib.Path)):
|
||||
cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files, _restore_shapes) # type: ignore[call-arg]
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
cpp_module = torch._C.import_ir_module(cu, os.fspath(f), map_location, _extra_files, _restore_shapes) # type: ignore[call-arg]
|
||||
else:
|
||||
cpp_module = torch._C.import_ir_module_from_buffer(
|
||||
cu, f.read(), map_location, _extra_files, _restore_shapes
|
||||
|
|
@ -182,8 +181,8 @@ def validate_map_location(map_location=None):
|
|||
|
||||
|
||||
def jit_module_from_flatbuffer(f):
|
||||
if isinstance(f, (str, pathlib.Path)):
|
||||
f = str(f)
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
f = os.fspath(f)
|
||||
return wrap_cpp_module(torch._C._load_jit_module_from_file(f))
|
||||
else:
|
||||
return wrap_cpp_module(torch._C._load_jit_module_from_bytes(f.read()))
|
||||
|
|
@ -231,8 +230,8 @@ def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
|
|||
if extra_files is None:
|
||||
extra_files = {}
|
||||
|
||||
if isinstance(f, (str, pathlib.Path)):
|
||||
f = str(f)
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
f = os.fspath(f)
|
||||
torch._C._save_jit_module(m._c, f, extra_files)
|
||||
else:
|
||||
s = torch._C._save_jit_module_to_bytes(m._c, extra_files)
|
||||
|
|
@ -259,7 +258,7 @@ def get_flatbuffer_module_info(path_or_file):
|
|||
'opname_to_num_args': {'aten::linear': 3} # Dict[str, int]
|
||||
}
|
||||
"""
|
||||
if isinstance(path_or_file, (str, pathlib.Path)):
|
||||
if isinstance(path_or_file, (str, os.PathLike)):
|
||||
with open(path_or_file, "rb") as f:
|
||||
all_bytes = f.read()
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import os
|
||||
|
||||
import pathlib
|
||||
|
||||
import torch
|
||||
|
||||
from torch.jit._serialization import validate_map_location
|
||||
|
|
@ -45,8 +43,8 @@ def _load_for_lite_interpreter(f, map_location=None):
|
|||
|
||||
map_location = validate_map_location(map_location)
|
||||
|
||||
if isinstance(f, (str, pathlib.Path)):
|
||||
cpp_module = torch._C._load_for_lite_interpreter(f, map_location)
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location)
|
||||
else:
|
||||
cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
|
||||
f.read(), map_location
|
||||
|
|
@ -104,8 +102,8 @@ def _get_model_bytecode_version(f_input) -> int:
|
|||
if os.path.isdir(f_input):
|
||||
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||
|
||||
if isinstance(f_input, (str, pathlib.Path)):
|
||||
return torch._C._get_model_bytecode_version(str(f_input))
|
||||
if isinstance(f_input, (str, os.PathLike)):
|
||||
return torch._C._get_model_bytecode_version(os.fspath(f_input))
|
||||
else:
|
||||
return torch._C._get_model_bytecode_version_from_buffer(f_input.read())
|
||||
|
||||
|
|
@ -136,8 +134,8 @@ def _get_mobile_model_contained_types(f_input) -> int:
|
|||
if os.path.isdir(f_input):
|
||||
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||
|
||||
if isinstance(f_input, (str, pathlib.Path)):
|
||||
return torch._C._get_mobile_model_contained_types(str(f_input))
|
||||
if isinstance(f_input, (str, os.PathLike)):
|
||||
return torch._C._get_mobile_model_contained_types(os.fspath(f_input))
|
||||
else:
|
||||
return torch._C._get_mobile_model_contained_types_from_buffer(f_input.read())
|
||||
|
||||
|
|
@ -159,10 +157,12 @@ def _backport_for_mobile(f_input, f_output, to_version):
|
|||
if os.path.isdir(f_input):
|
||||
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||
|
||||
if (isinstance(f_input, (str, pathlib.Path))) and (
|
||||
isinstance(f_output, (str, pathlib.Path))
|
||||
if (isinstance(f_input, (str, os.PathLike))) and (
|
||||
isinstance(f_output, (str, os.PathLike))
|
||||
):
|
||||
return torch._C._backport_for_mobile(str(f_input), str(f_output), to_version)
|
||||
return torch._C._backport_for_mobile(
|
||||
os.fspath(f_input), os.fspath(f_output), to_version
|
||||
)
|
||||
else:
|
||||
return torch._C._backport_for_mobile_from_buffer(
|
||||
f_input.read(), str(f_output), to_version
|
||||
|
|
@ -183,8 +183,8 @@ def _backport_for_mobile_to_buffer(f_input, to_version):
|
|||
if os.path.isdir(f_input):
|
||||
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||
|
||||
if isinstance(f_input, (str, pathlib.Path)):
|
||||
return torch._C._backport_for_mobile_to_buffer(str(f_input), to_version)
|
||||
if isinstance(f_input, (str, os.PathLike)):
|
||||
return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version)
|
||||
else:
|
||||
return torch._C._backport_for_mobile_from_buffer_to_buffer(
|
||||
f_input.read(), to_version
|
||||
|
|
@ -226,7 +226,7 @@ def _get_model_ops_and_info(f_input):
|
|||
if os.path.isdir(f_input):
|
||||
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||
|
||||
if isinstance(f_input, (str, pathlib.Path)):
|
||||
return torch._C._get_model_ops_and_info(str(f_input))
|
||||
if isinstance(f_input, (str, os.PathLike)):
|
||||
return torch._C._get_model_ops_and_info(os.fspath(f_input))
|
||||
else:
|
||||
return torch._C._get_model_ops_and_info(f_input.read())
|
||||
|
|
|
|||
|
|
@ -4,10 +4,9 @@ import importlib.machinery
|
|||
import inspect
|
||||
import io
|
||||
import linecache
|
||||
import os.path
|
||||
import os
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, BinaryIO, Callable, cast, Dict, Iterable, List, Optional, Union
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
|
|
@ -67,7 +66,7 @@ class PackageImporter(Importer):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
file_or_buffer: Union[str, torch._C.PyTorchFileReader, Path, BinaryIO],
|
||||
file_or_buffer: Union[str, torch._C.PyTorchFileReader, os.PathLike, BinaryIO],
|
||||
module_allowed: Callable[[str], bool] = lambda module_name: True,
|
||||
):
|
||||
"""Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
|
||||
|
|
@ -89,8 +88,8 @@ class PackageImporter(Importer):
|
|||
if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
|
||||
self.filename = "<pytorch_file_reader>"
|
||||
self.zip_reader = file_or_buffer
|
||||
elif isinstance(file_or_buffer, (Path, str)):
|
||||
self.filename = str(file_or_buffer)
|
||||
elif isinstance(file_or_buffer, (os.PathLike, str)):
|
||||
self.filename = os.fspath(file_or_buffer)
|
||||
if not os.path.isdir(self.filename):
|
||||
self.zip_reader = torch._C.PyTorchFileReader(self.filename)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user