[BE]: Add better handling of pathlib.Path with os calls (#116564)

Builds on #116562 to the rest of the instances of pathlib in the PyTorch.
* Uses more generic `os.PathLike` and `os.fspath` calls where appropiate
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116564
Approved by: https://github.com/malfet
This commit is contained in:
Aaron Gokaslan 2023-12-31 01:46:03 +00:00 committed by PyTorch MergeBot
parent 86cd6655a1
commit aef06c316b
5 changed files with 38 additions and 41 deletions

View File

@ -3,7 +3,6 @@ import dataclasses
import functools
import io
import json
import pathlib
import re
import sys
import os
@ -207,7 +206,7 @@ def export(
def save(
ep: ExportedProgram,
f: Union[str, pathlib.Path, io.BytesIO],
f: Union[str, os.PathLike, io.BytesIO],
*,
extra_files: Optional[Dict[str, Any]] = None,
opset_version: Optional[Dict[str, int]] = None,
@ -216,8 +215,8 @@ def save(
from .serde.schema import SCHEMA_VERSION
artifact: SerializedArtifact = serialize(ep, opset_version)
if isinstance(f, (str, pathlib.Path)):
f = str(f)
if isinstance(f, (str, os.PathLike)):
f = os.fspath(f)
with zipfile.ZipFile(f, 'w') as zipf:
# Save every field the SerializedArtifact to a file
@ -236,13 +235,13 @@ def save(
def load(
f: Union[str, pathlib.Path, io.BytesIO],
f: Union[str, os.PathLike, io.BytesIO],
*,
extra_files: Optional[Dict[str, Any]] = None,
expected_opset_version: Optional[Dict[str, int]] = None,
) -> ExportedProgram:
if isinstance(f, (str, pathlib.Path)):
f = str(f)
if isinstance(f, (str, os.PathLike)):
f = os.fspath(f)
with zipfile.ZipFile(f, 'r') as zipf:
# Check the version

View File

@ -3,7 +3,7 @@ import copy
import dataclasses
import inspect
import io
import pathlib
import os
import sys
import typing
import warnings
@ -200,7 +200,7 @@ def export(
def save(
ep: ExportedProgram,
f: Union[str, pathlib.Path, io.BytesIO],
f: Union[str, os.PathLike, io.BytesIO],
*,
extra_files: Optional[Dict[str, Any]] = None,
opset_version: Optional[Dict[str, int]] = None,
@ -217,7 +217,7 @@ def save(
Args:
ep (ExportedProgram): The exported program to save.
f (Union[str, pathlib.Path, io.BytesIO): A file-like object (has to
f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to
implement write and flush) or a string containing a file name.
extra_files (Optional[Dict[str, Any]]): Map from filename to contents
@ -256,7 +256,7 @@ def save(
def load(
f: Union[str, pathlib.Path, io.BytesIO],
f: Union[str, os.PathLike, io.BytesIO],
*,
extra_files: Optional[Dict[str, Any]] = None,
expected_opset_version: Optional[Dict[str, int]] = None,
@ -273,7 +273,7 @@ def load(
Args:
ep (ExportedProgram): The exported program to save.
f (Union[str, pathlib.Path, io.BytesIO): A file-like object (has to
f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to
implement write and flush) or a string containing a file name.
extra_files (Optional[Dict[str, Any]]): The extra filenames given in

View File

@ -8,7 +8,6 @@ This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import os
import pathlib
import torch
from torch.jit._recursive import wrap_cpp_module
@ -76,7 +75,7 @@ def save(m, f, _extra_files=None):
"""
if _extra_files is None:
_extra_files = {}
if isinstance(f, (str, pathlib.Path)):
if isinstance(f, (str, os.PathLike)):
m.save(f, _extra_files=_extra_files)
else:
ret = m.save_to_buffer(_extra_files=_extra_files)
@ -155,8 +154,8 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
_extra_files = {}
cu = torch._C.CompilationUnit()
if isinstance(f, (str, pathlib.Path)):
cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files, _restore_shapes) # type: ignore[call-arg]
if isinstance(f, (str, os.PathLike)):
cpp_module = torch._C.import_ir_module(cu, os.fspath(f), map_location, _extra_files, _restore_shapes) # type: ignore[call-arg]
else:
cpp_module = torch._C.import_ir_module_from_buffer(
cu, f.read(), map_location, _extra_files, _restore_shapes
@ -182,8 +181,8 @@ def validate_map_location(map_location=None):
def jit_module_from_flatbuffer(f):
if isinstance(f, (str, pathlib.Path)):
f = str(f)
if isinstance(f, (str, os.PathLike)):
f = os.fspath(f)
return wrap_cpp_module(torch._C._load_jit_module_from_file(f))
else:
return wrap_cpp_module(torch._C._load_jit_module_from_bytes(f.read()))
@ -231,8 +230,8 @@ def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
if extra_files is None:
extra_files = {}
if isinstance(f, (str, pathlib.Path)):
f = str(f)
if isinstance(f, (str, os.PathLike)):
f = os.fspath(f)
torch._C._save_jit_module(m._c, f, extra_files)
else:
s = torch._C._save_jit_module_to_bytes(m._c, extra_files)
@ -259,7 +258,7 @@ def get_flatbuffer_module_info(path_or_file):
'opname_to_num_args': {'aten::linear': 3} # Dict[str, int]
}
"""
if isinstance(path_or_file, (str, pathlib.Path)):
if isinstance(path_or_file, (str, os.PathLike)):
with open(path_or_file, "rb") as f:
all_bytes = f.read()
else:

View File

@ -1,7 +1,5 @@
import os
import pathlib
import torch
from torch.jit._serialization import validate_map_location
@ -45,8 +43,8 @@ def _load_for_lite_interpreter(f, map_location=None):
map_location = validate_map_location(map_location)
if isinstance(f, (str, pathlib.Path)):
cpp_module = torch._C._load_for_lite_interpreter(f, map_location)
if isinstance(f, (str, os.PathLike)):
cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location)
else:
cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
f.read(), map_location
@ -104,8 +102,8 @@ def _get_model_bytecode_version(f_input) -> int:
if os.path.isdir(f_input):
raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, pathlib.Path)):
return torch._C._get_model_bytecode_version(str(f_input))
if isinstance(f_input, (str, os.PathLike)):
return torch._C._get_model_bytecode_version(os.fspath(f_input))
else:
return torch._C._get_model_bytecode_version_from_buffer(f_input.read())
@ -136,8 +134,8 @@ def _get_mobile_model_contained_types(f_input) -> int:
if os.path.isdir(f_input):
raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, pathlib.Path)):
return torch._C._get_mobile_model_contained_types(str(f_input))
if isinstance(f_input, (str, os.PathLike)):
return torch._C._get_mobile_model_contained_types(os.fspath(f_input))
else:
return torch._C._get_mobile_model_contained_types_from_buffer(f_input.read())
@ -159,10 +157,12 @@ def _backport_for_mobile(f_input, f_output, to_version):
if os.path.isdir(f_input):
raise ValueError(f"The provided filename {f_input} is a directory")
if (isinstance(f_input, (str, pathlib.Path))) and (
isinstance(f_output, (str, pathlib.Path))
if (isinstance(f_input, (str, os.PathLike))) and (
isinstance(f_output, (str, os.PathLike))
):
return torch._C._backport_for_mobile(str(f_input), str(f_output), to_version)
return torch._C._backport_for_mobile(
os.fspath(f_input), os.fspath(f_output), to_version
)
else:
return torch._C._backport_for_mobile_from_buffer(
f_input.read(), str(f_output), to_version
@ -183,8 +183,8 @@ def _backport_for_mobile_to_buffer(f_input, to_version):
if os.path.isdir(f_input):
raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, pathlib.Path)):
return torch._C._backport_for_mobile_to_buffer(str(f_input), to_version)
if isinstance(f_input, (str, os.PathLike)):
return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version)
else:
return torch._C._backport_for_mobile_from_buffer_to_buffer(
f_input.read(), to_version
@ -226,7 +226,7 @@ def _get_model_ops_and_info(f_input):
if os.path.isdir(f_input):
raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, pathlib.Path)):
return torch._C._get_model_ops_and_info(str(f_input))
if isinstance(f_input, (str, os.PathLike)):
return torch._C._get_model_ops_and_info(os.fspath(f_input))
else:
return torch._C._get_model_ops_and_info(f_input.read())

View File

@ -4,10 +4,9 @@ import importlib.machinery
import inspect
import io
import linecache
import os.path
import os
import types
from contextlib import contextmanager
from pathlib import Path
from typing import Any, BinaryIO, Callable, cast, Dict, Iterable, List, Optional, Union
from weakref import WeakValueDictionary
@ -67,7 +66,7 @@ class PackageImporter(Importer):
def __init__(
self,
file_or_buffer: Union[str, torch._C.PyTorchFileReader, Path, BinaryIO],
file_or_buffer: Union[str, torch._C.PyTorchFileReader, os.PathLike, BinaryIO],
module_allowed: Callable[[str], bool] = lambda module_name: True,
):
"""Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
@ -89,8 +88,8 @@ class PackageImporter(Importer):
if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
self.filename = "<pytorch_file_reader>"
self.zip_reader = file_or_buffer
elif isinstance(file_or_buffer, (Path, str)):
self.filename = str(file_or_buffer)
elif isinstance(file_or_buffer, (os.PathLike, str)):
self.filename = os.fspath(file_or_buffer)
if not os.path.isdir(self.filename):
self.zip_reader = torch._C.PyTorchFileReader(self.filename)
else: