mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] simplify implementation for os.fspath (#133801)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133801 Approved by: https://github.com/anijain2305
This commit is contained in:
parent
d01a7a9faa
commit
9e806c1a60
|
|
@ -168,15 +168,3 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs):
|
|||
if isinstance(obj, cls):
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
|
||||
|
||||
def fspath(path):
|
||||
# Python equivalent of os.fspath
|
||||
if isinstance(path, (str, bytes)):
|
||||
return path
|
||||
elif hasattr(path, "__fspath__"):
|
||||
return path.__fspath__()
|
||||
else:
|
||||
raise TypeError(
|
||||
f"expected str, bytes or os.PathLike object, not {type(path).__name__}"
|
||||
)
|
||||
|
|
|
|||
6
torch/_dynamo/polyfills/builtins.py
Normal file
6
torch/_dynamo/polyfills/builtins.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
"""
|
||||
Python polyfills for builtins
|
||||
"""
|
||||
|
||||
|
||||
__all__ = [] # type: ignore[var-annotated]
|
||||
6
torch/_dynamo/polyfills/functools.py
Normal file
6
torch/_dynamo/polyfills/functools.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
"""
|
||||
Python polyfills for functools
|
||||
"""
|
||||
|
||||
|
||||
__all__ = [] # type: ignore[var-annotated]
|
||||
6
torch/_dynamo/polyfills/itertools.py
Normal file
6
torch/_dynamo/polyfills/itertools.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
"""
|
||||
Python polyfills for itertools
|
||||
"""
|
||||
|
||||
|
||||
__all__ = [] # type: ignore[var-annotated]
|
||||
|
|
@ -11,7 +11,13 @@ if TYPE_CHECKING:
|
|||
from types import ModuleType
|
||||
|
||||
|
||||
POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ()
|
||||
POLYFILLED_MODULE_NAMES: Tuple[str, ...] = (
|
||||
"builtins",
|
||||
"functools",
|
||||
"itertools",
|
||||
"os",
|
||||
"sys",
|
||||
)
|
||||
POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple(
|
||||
importlib.import_module(f".{submodule}", package=polyfills.__name__)
|
||||
for submodule in POLYFILLED_MODULE_NAMES
|
||||
|
|
|
|||
36
torch/_dynamo/polyfills/os.py
Normal file
36
torch/_dynamo/polyfills/os.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
"""
|
||||
Python polyfills for os
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import AnyStr
|
||||
|
||||
from ..decorators import substitute_in_graph
|
||||
|
||||
|
||||
__all__ = ["fspath"]
|
||||
|
||||
|
||||
# Copied from os.py in the standard library
|
||||
@substitute_in_graph(os.fspath)
|
||||
def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr:
|
||||
if isinstance(path, (str, bytes)):
|
||||
return path
|
||||
|
||||
path_type = type(path)
|
||||
try:
|
||||
path_repr = path_type.__fspath__(path) # type: ignore[arg-type]
|
||||
except AttributeError:
|
||||
if hasattr(path_type, "__fspath__"):
|
||||
raise
|
||||
raise TypeError(
|
||||
f"expected str, bytes or os.PathLike object, not {path_type.__name__}",
|
||||
) from None
|
||||
if isinstance(path_repr, (str, bytes)):
|
||||
return path_repr # type: ignore[return-value]
|
||||
raise TypeError(
|
||||
f"expected {path_type.__name__}.__fspath__() to return str or bytes, "
|
||||
f"not {type(path_repr).__name__}",
|
||||
)
|
||||
6
torch/_dynamo/polyfills/sys.py
Normal file
6
torch/_dynamo/polyfills/sys.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
"""
|
||||
Python polyfills for sys
|
||||
"""
|
||||
|
||||
|
||||
__all__ = [] # type: ignore[var-annotated]
|
||||
|
|
@ -4,7 +4,6 @@ import dataclasses
|
|||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
|
|
@ -15,7 +14,7 @@ import torch._C
|
|||
import torch._numpy as tnp
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
from .. import config, polyfills, variables
|
||||
from .. import config, variables
|
||||
from ..bytecode_transformation import create_call_function, create_instruction
|
||||
from ..create_parameter_op import do_not_convert_to_tracable_parameter
|
||||
from ..exc import unimplemented
|
||||
|
|
@ -30,7 +29,6 @@ from ..source import (
|
|||
)
|
||||
from ..utils import (
|
||||
check_unspec_or_constant_args,
|
||||
hashable,
|
||||
identity,
|
||||
is_tensor_base_attr_getter,
|
||||
proxy_args_kwargs,
|
||||
|
|
@ -49,10 +47,6 @@ from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObject
|
|||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
POLYFILL_SUPPORTED_PYTHON_MODULE_METHODS = {
|
||||
os.fspath: polyfills.fspath,
|
||||
}
|
||||
|
||||
|
||||
class NO_SUCH_SUBOBJ:
|
||||
pass
|
||||
|
|
@ -1167,12 +1161,6 @@ class PythonModuleVariable(VariableTracker):
|
|||
else:
|
||||
attr_value = self.value.__dict__[name]
|
||||
|
||||
if hashable(attr_value):
|
||||
if polyfill_fn := POLYFILL_SUPPORTED_PYTHON_MODULE_METHODS.get(
|
||||
attr_value, None
|
||||
):
|
||||
return variables.UserFunctionVariable(polyfill_fn)
|
||||
|
||||
if self.source:
|
||||
new_source = AttrSource(self.source, name)
|
||||
return VariableBuilder(tx, new_source)(attr_value)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user